Swin Transformer: Windows of Attention
The popularity of transformer models in natural language processing inspired the computer vision (CV) community to adapt transformers for CV tasks. Vision Transformer (ViT) was one of the earliest such attempts. ViT splits the input image into equal-size patches, which are treated as the visual equivalent of word embeddings in a sentence. Although ViT achieved promising performance on image classification, its self-attention operation has computational complexity that is quadratic in the number of patches, which makes it unsuitable for high-resolution images. Moreover, CV tasks such as image segmentation are dense, pixel-level prediction tasks that require finer image patches. Swin Transformer overcomes these challenges with a unique shifted windows approach whose computational complexity scales linearly with image size.
Window Attention
Instead of computing self-attention over all the patches in an input image, window attention is restricted to the patches that fall within a window of predefined shape. The input image is split into multiple non-overlapping windows of equal size for the attention computation. For example, in the window attention figure included below, the image is split into 8 × 8 patches and partitioned into 4 non-overlapping windows of 4 × 4 patches each. The attention mechanism used by ViT would have computational complexity of order O(64 × 64), since there are 64 patches in the image, whereas window attention would be of order O(4 × 16 × 16), because the image has 4 windows of 16 patches each. More generally, the complexity of the dot product in self-attention (SA) and window self-attention (WSA) for an image with h × w patches, split into windows of M × M patches each, can be written as shown below.
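Reconstructing the complexity analysis from the Swin Transformer paper (C denotes the embedding dimension; the first term accounts for the linear projections and the second for the attention dot products):

Ω(SA) = 4hwC² + 2(hw)²C

Ω(WSA) = 4hwC² + 2M²hwC

For a fixed window size M, the WSA cost is linear in the number of patches hw, whereas the SA cost grows quadratically.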
Shifted Window Attention
Although window-based self-attention is computationally efficient compared to regular self-attention, WSA is limited to patches within the same window, as patches belonging to different windows do not interact. To overcome this limitation, the window partition is shifted in consecutive layers by M/2 patches horizontally and M/2 patches vertically. In the following figure, windows 1, 2, 3, and 4 are shifted by 2 patches, since M = 4 in this case. Observe that this shifting operation results in more windows in the subsequent layer than in the current layer, and that the window sizes are no longer consistent. A straightforward way to compute attention in layer l+1 is to pad the smaller windows to size M × M, but this increases the cost of attention because the shift increases the number of windows from (h/M × w/M) to ((h/M + 1) × (w/M + 1)). As the figure shows, going from layer l to layer l+1 the number of windows increases from 2 × 2 to 3 × 3, which adversely affects the attention computation.
To preserve computational efficiency, the authors proposed an efficient batch computation approach: cyclically shift (roll) the feature map so that the extra windows are merged back, perform the attention computation on the cyclically shifted configuration, and then shift the windows back to their original locations. The cyclic shifting is illustrated in the following figure:
Patches are color-coded based on the window they belong to in layer l+1, and the color coding is applied only to the additional windows that come into existence due to the window shift performed at layer l+1. Windows 2, 3, and 4 are augmented to absorb these additional windows, which reduces the number of windows from 9 back to 4. Because the newly formed windows comprise sub-windows containing image features that were not adjacent in the original image, a masking mechanism is employed to restrict the attention computation to each sub-window. The sub-windows are rolled back to their original locations after the attention computation. Cyclic shifting keeps the number of windows the same as in the regular window partition and is therefore efficient.
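To make the cyclic shift concrete, here is a minimal PyTorch sketch of the window partition and roll operations for the 8 × 8 patch example above (my own illustration, not the official implementation; the attention and masking steps are omitted):

```python
import torch

def window_partition(x, M):
    """Split a feature map (B, H, W, C) into non-overlapping M x M windows,
    returning a tensor of shape (num_windows * B, M, M, C)."""
    B, H, W, C = x.shape
    x = x.reshape(B, H // M, M, W // M, M, C)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, M, M, C)

def window_reverse(windows, M, H, W):
    """Inverse of window_partition: reassemble windows into (B, H, W, C)."""
    B = windows.shape[0] // ((H // M) * (W // M))
    x = windows.reshape(B, H // M, W // M, M, M, -1)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(B, H, W, -1)

M = 4
x = torch.randn(1, 8, 8, 96)                      # (B, H, W, C): 8 x 8 patches

# Regular window attention: partition into 4 windows of 16 patches each.
windows = window_partition(x, M)                   # shape (4, 4, 4, 96)

# Shifted window attention: roll the feature map by M/2 so that the shifted
# partition again yields exactly (H/M) x (W/M) windows.
shifted = torch.roll(x, shifts=(-(M // 2), -(M // 2)), dims=(1, 2))
shifted_windows = window_partition(shifted, M)     # still 4 windows

# ... masked attention within each window would run here ...

# Roll back to the original layout after attention.
restored = torch.roll(window_reverse(shifted_windows, M, 8, 8),
                      shifts=(M // 2, M // 2), dims=(1, 2))
assert torch.allclose(restored, x)
```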
Network Architecture
The Swin Transformer model comprises 4 stages that progressively reduce the input resolution, similar to convolutional neural networks. Each stage contains multiple Swin Transformer blocks, with the number depending on the model variant. The figure above shows Swin-T, with 2, 2, 6, and 2 blocks in the four stages respectively and an embedding dimension C of 96. Swin-S, Swin-B, and Swin-L share a (2, 2, 18, 2) structure and differ in the embedding dimension C, set to 96, 128, and 192 respectively. It is important to note that a Swin Transformer block does not alter the shape of its input, i.e., the input and output of a Swin Transformer block have the same shape.
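For quick reference, the variant configurations described above can be summarized in a small Python dictionary (an illustrative sketch of my own; the depths and embedding dimensions follow the paper):

```python
# Blocks per stage ("depths") and embedding dimension C for each Swin variant.
swin_configs = {
    "Swin-T": {"depths": (2, 2, 6, 2),  "embed_dim": 96},
    "Swin-S": {"depths": (2, 2, 18, 2), "embed_dim": 96},
    "Swin-B": {"depths": (2, 2, 18, 2), "embed_dim": 128},
    "Swin-L": {"depths": (2, 2, 18, 2), "embed_dim": 192},
}
```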
Swin Transformer Blocks
Swin Transformer uses two types of blocks that differ in the attention mechanism employed. In every stage, the first transformer block uses regular window multi-head self-attention (W-MSA), and the following blocks alternate between shifted window multi-head self-attention (SW-MSA) and W-MSA. All Swin Transformer blocks use a two-layer multi-layer perceptron (MLP), layer normalization (LN), and residual connections, similar to regular transformer blocks.
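The layout of a single block can be sketched as follows (a simplified illustration, not the reference implementation; the window attention is stubbed with nn.MultiheadAttention and the shift/mask logic is only indicated by a flag):

```python
import torch
import torch.nn as nn

class SwinBlockSketch(nn.Module):
    """Simplified Swin block: LN -> (S)W-MSA -> residual -> LN -> MLP -> residual."""

    def __init__(self, dim, num_heads, shifted=False, mlp_ratio=4):
        super().__init__()
        self.shifted = shifted          # True for SW-MSA blocks, False for W-MSA
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_ratio * dim),
            nn.GELU(),
            nn.Linear(mlp_ratio * dim, dim),
        )

    def forward(self, x):
        # x: (num_windows * B, M*M, dim), i.e. tokens grouped per window.
        # In the real model, `shifted` controls the cyclic roll and the
        # attention mask applied around this step.
        h = self.norm1(x)
        h, _ = self.attn(h, h, h)
        x = x + h                        # residual around attention
        x = x + self.mlp(self.norm2(x))  # residual around MLP
        return x

# A stage alternates W-MSA and SW-MSA blocks; shapes are preserved throughout.
blocks = nn.Sequential(
    SwinBlockSketch(dim=96, num_heads=3, shifted=False),
    SwinBlockSketch(dim=96, num_heads=3, shifted=True),
)
tokens = torch.randn(4, 49, 96)          # 4 windows of 7 x 7 = 49 patches
assert blocks(tokens).shape == tokens.shape
```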
Stages
The input RGB image of shape (H, W, 3) is broken into non-overlapping 4 × 4 patches and flattened by a patch partition block; each flattened patch has length 48 (4 × 4 × 3 = 48). In stage 1, a linear embedding layer projects each patch to an embedding dimension C, and the result is processed by the Swin Transformer blocks of stage 1. Stages 2, 3, and 4 each begin with a patch merging layer that concatenates the features of every group of 2 × 2 neighboring patches (halving the spatial resolution) and linearly projects the concatenated 4C-dimensional features to 2C dimensions, followed by multiple transformer blocks.
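Patch merging can be sketched as follows (an illustrative sketch under the description above, not the official code):

```python
import torch
import torch.nn as nn

class PatchMergingSketch(nn.Module):
    """Group 2 x 2 neighboring patches, concatenate their features (C -> 4C),
    and linearly project to 2C, halving the spatial resolution."""

    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(4 * dim)
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

    def forward(self, x):
        # x: (B, H, W, C)
        x0 = x[:, 0::2, 0::2, :]   # top-left patch of each 2 x 2 group
        x1 = x[:, 1::2, 0::2, :]   # bottom-left
        x2 = x[:, 0::2, 1::2, :]   # top-right
        x3 = x[:, 1::2, 1::2, :]   # bottom-right
        x = torch.cat([x0, x1, x2, x3], dim=-1)    # (B, H/2, W/2, 4C)
        return self.reduction(self.norm(x))        # (B, H/2, W/2, 2C)

# Example: the stage 1 -> stage 2 transition of Swin-T.
merge = PatchMergingSketch(dim=96)
x = torch.randn(1, 56, 56, 96)
print(merge(x).shape)   # torch.Size([1, 28, 28, 192])
```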
The input image size is assumed to be (224, 224, 3) for Swin-T to ensure smooth size reduction across the stages. The window size M is set to 7 and the embedding dimension C to 96. To summarize, the input and output shapes of the stages of Swin-T for an input image of size 224 are:
Patch partition: (224, 224, 3) to (56, 56, 48)
Stage 1: (56, 56, 48) to (56, 56, 96)
Stage 2: (56, 56, 96) to (28, 28, 192)
Stage 3: (28, 28, 192) to (14, 14, 384)
Stage 4: (14, 14, 384) to (7, 7, 768)
The output of stage 4 has a spatial size equal to the predefined window size M = 7. An adaptive average pooling layer followed by a layer norm converts this feature map into a 1-D representation of size 768, and a linear layer acting as the classification head projects it to the number of classes.
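A classification head following this description might look like the sketch below (my own, hypothetical arrangement; the official implementation may order the normalization and pooling differently, and the channels-first layout is an assumption):

```python
import torch
import torch.nn as nn

num_classes = 1000
head = nn.Sequential(
    nn.AdaptiveAvgPool2d(1),    # (B, 768, 7, 7) -> (B, 768, 1, 1)
    nn.Flatten(1),              # -> (B, 768)
    nn.LayerNorm(768),
    nn.Linear(768, num_classes),
)

stage4_out = torch.randn(2, 768, 7, 7)   # stage-4 features, channels first
print(head(stage4_out).shape)            # torch.Size([2, 1000])
```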
Relative position bias
The authors found a significant improvement in model performance by including a relative position bias in each head when computing similarity. The query, key, and value are of shape (M², d), where M² is the number of patches in a window and d is the embedding dimension. The scaled dot product between the query and the key has shape (M², M²).
A relative position bias B of shape (M², M²) is added to this scaled dot product. Since the relative position along each axis can take values in [-M+1, M-1], i.e., 2M-1 possible values per axis, the entries of B are taken from a smaller, learnable bias matrix of size ((2M-1), (2M-1)).
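The lookup from the (2M-1) × (2M-1) table into the (M², M²) bias B can be sketched as follows (an illustrative sketch; in the real model the bias table is a learnable parameter, and nH here denotes the number of attention heads):

```python
import torch

M, nH = 7, 3
bias_table = torch.zeros((2 * M - 1) ** 2, nH)     # learnable in the real model

# Relative offsets between every pair of patches in an M x M window.
coords = torch.stack(torch.meshgrid(torch.arange(M), torch.arange(M), indexing="ij"))
coords = coords.flatten(1)                         # (2, M*M)
rel = coords[:, :, None] - coords[:, None, :]      # (2, M*M, M*M), values in [-M+1, M-1]
rel = rel.permute(1, 2, 0) + (M - 1)               # shift to [0, 2M-2]
index = rel[:, :, 0] * (2 * M - 1) + rel[:, :, 1]  # flatten the 2-D offset to one index

# Gather the per-head bias B of shape (nH, M*M, M*M); it is added to
# q @ k.T / sqrt(d) before the softmax.
B = bias_table[index.reshape(-1)].reshape(M * M, M * M, nH).permute(2, 0, 1)
print(B.shape)   # torch.Size([3, 49, 49])
```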
Swin Transformer model performance
In recent years, computer vision has been undergoing a paradigm shift from CNN-based models toward transformers. Swin Transformer was introduced as a general-purpose backbone for various computer vision tasks such as classification, object detection, and semantic segmentation.
Image Classification
Dataset: ImageNet-1K; training: 1.28M images; validation: 50K images; classes: 1,000
Swin-B achieved 84.5% top-1 accuracy, and Swin-L, when pretrained on ImageNet-22K and fine-tuned on ImageNet-1K, achieved 87.3% top-1 accuracy.
Object Detection
Dataset: COCO; training: 118K images; validation: 5K images; test: 20K images.
Swin-L obtained 58.7 box AP and 51.1 mask AP on COCO object detection and instance segmentation.
Semantic Segmentation
Datasets: ADE20K; training: 20K images; validation: 2K images; testing: 3K images; 150 semantic categories.
Swin-L pretrained on ImageNet-22K achieved 53.5 mIoU on the validation set and a test score of 62.8 for semantic segmentation on the ADE20K dataset.
Swin Transformer surpassed state-of-the-art results on various CV tasks with its novel shifted window attention mechanism. It achieves this performance with a self-attention mechanism whose complexity is linear with respect to image size, without losing the global picture of the image.
References
Liu, Ze, et al. "Swin Transformer: Hierarchical Vision Transformer using Shifted Windows." arXiv preprint arXiv:2103.14030 (2021). Link: https://arxiv.org/abs/2103.14030
Dosovitskiy, Alexey, et al. "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale." arXiv preprint arXiv:2010.11929 (2020). Link: https://arxiv.org/abs/2010.11929
Thank you for reading the article! I found the shifted window attention mechanism exciting, which motivated me to write this piece. There is scope for a better explanation of the relative position bias concept used in this paper and many others, and I plan to write a separate article on relative position bias in the future.