Swin Transformer论文精读-学习笔记
原论文链接:https://arxiv.org/abs/2103.14030
原视频链接:https://www.bilibili.com/video/BV13L4y1475U
前言
Swin Transformer,也就是Hierarchical Shifted Window Vision Transformer,是一种基于Transformer架构的计算机视觉模型,旨在解决传统Transformer在处理高分辨率图像时计算成本过高的问题。Swin Transformer通过引入窗口注意力机制和滑动窗口策略,有效地降低了计算复杂度,同时保持了强大的特征表达能力。Swin Transformer还在ViT的基础上,引入了类似传统CNN的分块的层级式结构,从而能够提取图像在不同尺度上的特征。此外,Swin Transformer在CV的绝大多数任务上都取得了优异的表现,再次证明了Transformer在CV领域的可行性与有效性,成为近年来计算机视觉领域的重要研究方向之一。
Swin Transformer论文精读
- 将传统Transformer应用到视觉任务会面临哪些挑战?
- 首先是目标的尺度不统一,比如同样的“汽车”对象,在不同图片里的大小可能差距很大;然后是图像的高分辨率问题,过高的像素会导致输入Transformer的序列长度过长,难以计算。
Swin Transformer的创作动机,是为了证明Transformer在CV领域可以作为一个通用的骨干网络,能够适用于所有的CV任务,包括分类,检测,分割,视频等,而不仅仅是图像分类任务。

- ViT在CV任务中已经表现很好了,为什么还要使用Swin Transformer呢?
首先,ViT的一个明显缺陷(这也是Transformer的固有属性)是,由于它在patch embedding后使用多层连续的transformer编码器块,因此数据在经过这些块时,序列长度是保持不变的,就等于一开始划分出的patches数量,这就导致对于底层的数据来说,每个patch太大了,无法捕捉到足够细粒度的特征(如物体边缘),而对于顶层的数据来说,每个patch又太小了,这就导致patches数量很多,对于高分辨率图像来说,计算成本过高;对于图像中不同尺度的对象,ViT尝试使用固定大小的patch进行捕捉,但这显然不够灵活,无法适应不同尺度的对象,而Swin Transformer通过引入层级式结构和滑动窗口机制,能够像CNN一样不断通过下采样提取不同尺度的特征,更好地适应不同尺度的对象,动态调整每一层transformer块的输入序列长度,从而更有效地捕捉图像中的特征。此外,ViT每个注意力操作都是在全局范围的patches上进行的,这就导致了计算成本过高,而Swin Transformer通过引入窗口注意力机制,将计算复杂度从ViT的O(n^2)降低到O(n),使得Swin Transformer能够处理更高分辨率的图像,同时保持了可接受的计算成本。
- 为了解决ViT成本高的问题,为什么不直接把一开始的patch变大一点?
首先,过大的patch会导致每个patch包含的信息过多,无法捕捉到图像细节;其次,由于不同对象尺度不同,大patch将难以识别那些较小的对象;最后,ViT每一层都是固定大小的patch范围,导致缺乏从“局部细节”到“整体结构”的渐进过程。
Swin Transformer与传统CNN相比有何优势?
Swin Transformer虽然在整体架构和设计思路上很像CNN,但因为使用了注意力机制,Swin有三个显著的优势:1.CNN卷积核的参数是固定的。不管输入是什么,卷积核都用同样的权重去“套”它,是内容无关的,而Swin的注意力权重是根据输入内容动态计算出来的,是内容自适应的;2.CNN的感受野是受限的。每一层卷积只能看到周围有限范围内的像素,想要看到整张图,必须堆叠很多层,而Swin虽然也有窗口限制,但即使在最底层,一个窗口内的所有 Patch 都是一步到位互相连接的。通过移位窗口(Shifted Window),它扩大感受野的速度比 CNN 快得多,且联系更直接;3.CNN(如ResNet)存在明显的性能瓶颈,到了后期,无论你给它多少数据,它的性能提升非常缓慢,容易饱和,而Swin继承了 Transformer 的优良基因——极其抗饱和。只要你有足够大的数据集,Swin 的精度上限比 ResNet 高得多。
为什么Swin的计算复杂度是O(n)而ViT是O(n^2)?
- 简单来讲,ViT在每一层只做一次注意力汇聚,而传给该注意力的序列长度等于patches数量,也就是与图像尺寸成正比,因此每一层的计算复杂度是序列长度的平方,即O(n^2);而Swin在每一层都要计算多个窗口内的注意力,每个窗口内的序列长度是固定的,与图像尺寸无关(此处应用了CNN的locality的归纳偏置,一个对象的不同部位往往是聚集在一起的,不会分散到各处,因此没必要用全局注意力),而窗口数量是与图像尺寸成正比的,因此每一层的计算复杂度是O(n)。
当然,ViT的重大意义在于它没有对Transformer编码器做任何更改,仅仅是通过数据预处理就让Transformer能够处理图像数据,极大的揭示了Transformer在多模态模型的可能性;而Swin针对CV任务对架构做了针对性的更改,虽然牺牲了一定的通用性,但在CV任务上取得了更好的性能。
- 模型架构

数据处理流程:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16输入图像(H,W,3)
-[patch partition]-> ((H/4),(W/4),48) (4x4x3=48)
-[linear embedding]-> ((H/4),(W/4),C) (C是模型的维度,通常是96,192,384,768等)
-[Transformer编码器]-> ((H/4),(W/4),C)
-[patch merging]-> ((H/8),(W/8),2C)
-[Transformer编码器]-> ((H/8),(W/8),2C)
-[patch merging]-> ((H/16),(W/16),4C)
-[Transformer编码器]-> ((H/16),(W/16),4C)
-[patch merging]-> ((H/32),(W/32),8C)
-[Transformer编码器]-> ((H/32),(W/32),8C)
->输出特征图((H/32),(W/32),8C)
以上是模型的主干部分,可以根据下游任务需要接入不同的头部,例如分类头,检测头,分割头等。Swin Transformer的主干部分是一个层级式结构,由四个阶段组成,每个阶段包含多个Transformer编码器块和一个patch merging层。每个阶段的输入输出尺寸逐渐减小,特征维度逐渐增加,从而能够提取不同尺度的特征。
例如,如果要进行图片分类,可以在模型的主干部分后面加一个全局平均池化层和一个线性分类头;如果要进行目标检测,可以在模型的主干部分后面加一个FPN(Feature Pyramid Network)和一个检测头;如果要进行语义分割,可以在模型的主干部分后面加一个上采样模块和一个分割头等。
patch merging层的作用是对特征进行下采样和维度增加,类似于CNN中的池化层。它的处理流程:首先,按位置将输入特征图划分为不重叠的2x2块;然后,将每个块在特征维度上进行拼接,得到一个新的大小为原来一半的特征图,其中每个位置的特征维度是原来的4倍;最后,通过一个1x1卷积将特征维度减少为原来的2倍,从而实现下采样。

- 窗口自注意力:为了解决图片分辨率过高导致transformer序列过长的问题,Swin Transformer引入了窗口自注意力机制,即在每层的第一个Transformer块中,先将输入特征图划分为不重叠的固定大小的窗口(通常一个窗口包含7*7个patch),然后在每个窗口内独立地计算自注意力,从而将计算复杂度从O(n^2)降低到O(n)。

- 移动窗口:然而,如果仅仅在每个小窗口内进行注意力汇聚,会导致不同窗口间无法进行通信,也就无法捕捉到跨窗口的全局信息。为了解决这个问题,Swin Transformer在每层的窗口注意力之后引入了一个移动窗口Transformer块,即将特征图中的所有窗口进行偏移(通常斜向移动半个窗口边长的距离),使得每个窗口与之前的其他窗口有一部分重叠,从而实现跨窗口的信息交流。通过这种方式,Swin Transformer能够捕捉到图像中的全局信息。

Transformer块的架构图:

- Transformer块中的数据处理流程:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16输入特征图(H,W,C)
-[LayerNorm]-> (H,W,C)
-[Window Multi-Head Self-Attention]-> (H,W,C)
-[DropPath]-> (H,W,C)
-[LayerNorm]-> (H,W,C)
-[MLP]-> (H,W,C)
-[DropPath]-> (H,W,C)
-[LayerNorm]-> (H,W,C)
-[Shifted Window Multi-Head Self-Attention]-> (H,W,C)
-[DropPath]-> (H,W,C)
-[LayerNorm]-> (H,W,C)
-[MLP]-> (H,W,C)
-[DropPath]-> (H,W,C)
->输出特征图(H,W,C)
- 移动窗口后,原来的4个大小相同的窗口变成了9个大小不同的窗口(这里的‘4’和‘9’只是为了贴合示意图中的数量,不代表真实数量),无法放到同一个batch中并行处理,且窗口的数量也增加了一倍多,影响计算效率,怎么解决?
- 为了解决这个问题,Swin Transformer引入了一种循环移位机制,如下图。

首先,将边角处的窗口进行循环移位(例如从左上移到右下),拼成与中间大窗口的大小一致的若干个窗口,且这些窗口的总数量比之前少。现在,我们得到了数量更少(与移动窗口前一致)且大小相同的窗口,可以放到同一个batch中并行处理。然后,通过带掩码的注意力汇聚在所有这些窗口上计算注意力,掩码的作用是防止本来不属于同一窗口,但循环移位后被拼成同一个窗口的不同窗口之间的信息交流,从而保持了移动窗口的效果。最后,将这些窗口再循环移位回原来的位置,恢复到特征图原来的形状,继续进行后续的处理。
关于使用的掩码,简单来讲,在“合并”后的窗口内进行注意力操作时,由于该“合并”窗口可能是由若干个不相关的窗口拼接而成的,那么通过QK相乘得到的attention scores矩阵中,那些来自同一窗口的token之间的attention score是正常计算的,而那些来自不同窗口的token之间的attention score也被计算了出来,但不是我们想要的,因此需要设计一个特殊形状的掩码矩阵,将那些来自不同窗口的token之间的attention score设置为负无穷,这样在softmax后它们的权重就变成了0,从而实现了移动窗口的效果。mask矩阵的具体形状和原窗口在特征图中的位置有关,对于需要忽略的位置,mask矩阵的元素为负无穷(通常设为-100),对于需要保留的位置,mask矩阵的元素为0,然后将mask与attention scores矩阵相加,就得到了带掩码的attention scores矩阵。mask矩阵的示意图如下:


实际上,循环移位这些操作的目的仅仅是为了在移动窗口后能够并行计算注意力,如果不考虑计算效率,完全可以直接在移动窗口后的特征图上进行注意力计算,而不需要进行循环移位和掩码操作。
后记
Swin Transformer借了CNN的形,却用了Transformer的神,有效的融合了CNN针对图像处理的归纳偏置与Transformer的全局建模能力,在CV领域取得了优异的表现。如果说ViT证明了Transformer在CV领域可行,那么Swin Transformer则进一步发展出了一种与视觉任务深度结合的Transformer骨干架构,是CV领域Transformer应用的一个重要里程碑。