Skip to content

残差网络

ResNet [2015]

He K , Zhang X , Ren S ,et al.Deep Residual Learning for Image Recognition[J].IEEE, 2016.DOI:10.1109/CVPR.2016.90.

官方 PyTorch 实现:torchvision.models.resnet

residual

图:一个正常块(左图)和一个残差块(右图)

residual

图:网络架构比较

ResNet 要解决什么问题

在 ResNet 出现之前,一个很自然的想法是:

网络越深,表达能力越强,效果应该越好。

但实验里会遇到一个反直觉的问题:当网络越来越深时,训练误差反而可能变高。

注意,这里说的不是过拟合。过拟合通常表现为:

  • 训练集效果很好;
  • 测试集效果变差。

而 ResNet 论文里关心的 degradation problem 是:

  • 网络加深以后;
  • 训练集误差也变差;
  • 说明模型连训练数据都没学好。

这就说明问题不只是模型容量,而是 深层网络本身不好优化

ResNet 的核心做法就是给网络加一条“捷径”:

y=F(x)+x

其中:

  • x 是输入;
  • F(x) 是卷积层真正要学习的变化;
  • x 这条路直接跳过去,叫做 shortcut / skip connection;
  • 最后把两条路相加。

TIP

普通网络希望每一层直接学到目标映射:

H(x)

ResNet 不让卷积层直接学 H(x),而是让它学:

F(x)=H(x)x

最后再通过:

H(x)=F(x)+x

把结果还原出来。

这就是 residual learning,也就是“残差学习”。

从梯度角度看

残差块写成:

y=F(x)+x

如果损失函数是 L,反向传播时有:

Lx=Ly(F(x)x+I)

这里的 I 来自 shortcut 那条直接相加的路径。

它的直观意义是:

梯度不一定非要穿过所有卷积层,至少还有一条直接路径可以往前传。

这就是为什么残差连接能缓解深层网络中的优化困难。

代码

python
resnet18  = ResNet(BasicBlock, [2, 2, 2, 2])
resnet34  = ResNet(BasicBlock, [3, 4, 6, 3])
resnet50  = ResNet(Bottleneck, [3, 4, 6, 3])
resnet101 = ResNet(Bottleneck, [3, 4, 23, 3])
resnet152 = ResNet(Bottleneck, [3, 8, 36, 3])

辅助卷积函数

python
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
  • bias=False 因为卷积后面通常接 BatchNorm2d。BN 有自己的可学习仿射参数,所以卷积 bias 通常可以省掉。
  • padding=dilation 普通卷积时 dilation=1,所以 padding=1,这样 3×3 卷积在 stride=1 时不会改变特征图尺寸。空洞卷积时,padding 会跟着 dilation 变大。
  • groups 默认是 1。普通 ResNet 用 groups=1,ResNeXt 会用更大的 groups。
python
def conv1x1(in_planes, out_planes, stride=1):
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )

它主要有两个作用:

  • 在 Bottleneck 里面改变通道数;
  • 在 downsample 分支里对 shortcut 做投影,让 shortcut 的形状和主分支输出对齐。

1×1 卷积主要负责“换通道”,有时候也通过 stride 顺便改变空间尺寸。

BasicBlock:ResNet-18 / ResNet-34 的残差块

简化版

python
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        self.downsample = downsample

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        # 第二个卷积后面不立刻 ReLU,而是先和 shortcut 相加,再统一 ReLU。
        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out = out + identity
        out = self.relu(out)

        return out

Bottleneck:ResNet-50 / 101 / 152 的残差块

Bottleneck 先用 1×1 卷积降到 64 通道,然后在 64 通道上做 3×3 卷积,最后再用 1×1 卷积升回 256 通道。

简化版:

python
class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        width = planes

        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = nn.BatchNorm2d(width)

        self.conv2 = conv3x3(width, width, stride)
        self.bn2 = nn.BatchNorm2d(width)

        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out = out + identity
        out = self.relu(out)

        return out

WARNING

标准 ResNet 中,groups = 1,base_width = 64,所以 width = planes

但 ResNeXt 和 Wide ResNet 会改变 groups 或 width_per_group,因此中间通道数 width 会变化。

Bottleneck 正是通过 groups 和 base_width 来兼容普通 ResNet、ResNeXt 和 Wide ResNet。

python
width = int(planes * (base_width / 64.0)) * groups

ResNet V1.5

这里:

python
self.conv1 = conv1x1(inplanes, width)
self.conv2 = conv3x3(width, width, stride)
self.conv3 = conv1x1(width, planes * self.expansion)

TorchVision 的 Bottleneck 把下采样 stride 放在第二个 3×3 卷积上,而原论文放在第一个 1×1 卷积上;这个变体通常叫 ResNet V1.5。

初始化

python
# [stem] conv1 -> bn1 -> relu -> maxpool
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = norm_layer(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
# [四个残差 stage] layer1 -> layer2 -> layer3 -> layer4
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
# [分类头] avgpool -> flatten -> fc
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)

layer1、layer2、layer3、layer4 不是单层网络,而是四个 stage。每个 stage 里面有多个 residual block。

python
# ResNet-18
# layer1 里有 2 个 BasicBlock
# layer2 里有 2 个 BasicBlock
# layer3 里有 2 个 BasicBlock
# layer4 里有 2 个 BasicBlock
layers = [2, 2, 2, 2]
python
# ResNet-50
# layer1 里有 3 个 Bottleneck
# layer2 里有 4 个 Bottleneck
# layer3 里有 6 个 Bottleneck
# layer4 里有 3 个 Bottleneck
layers = [3, 4, 6, 3]

_make_layer

它负责创建一个 stage,也就是创建:一个可能改变尺寸 / 通道数的 block + 若干个保持尺寸 / 通道数不变的 block

简化版:

python
# block: 使用 BasicBlock 还是 Bottleneck
# planes:这个 stage 的基础通道数
# blocks:这个 stage 里堆几个 block
# stride:第一个 block 是否进行空间下采样
# dilate:是否用 dilation 替代 stride 下采样
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
    downsample = None
    # downsample 的作用是让 shortcut 分支的形状和主分支输出一致。
    if stride != 1 or self.inplanes != planes * block.expansion:
        downsample = nn.Sequential(
            conv1x1(self.inplanes, planes * block.expansion, stride),
            norm_layer(planes * block.expansion),
        )

    layers = []

    layers.append(
        block(
            self.inplanes,
            planes,
            stride,
            downsample,
        )
    )

    self.inplanes = planes * block.expansion

    for _ in range(1, blocks):
        layers.append(
            block(
                self.inplanes,
                planes,
            )
        )

    return nn.Sequential(*layers)

前向传播

python
def _forward_impl(self, x):
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.maxpool(x)

    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)

    x = self.avgpool(x)
    x = torch.flatten(x, 1)
    x = self.fc(x)

    return x

ResNet-18 的形状变化

输入:

txt
[B, 3, 224, 224]

ResNet-18:

阶段输出形状
input[B, 3, 224, 224]
conv1[B, 64, 112, 112]
maxpool[B, 64, 56, 56]
layer1[B, 64, 56, 56]
layer2[B, 128, 28, 28]
layer3[B, 256, 14, 14]
layer4[B, 512, 7, 7]
avgpool[B, 512, 1, 1]
flatten[B, 512]
fc[B, 1000]

ResNet-50 的形状变化

输入:

txt
[B, 3, 224, 224]

ResNet-50:

阶段输出形状
input[B, 3, 224, 224]
conv1[B, 64, 112, 112]
maxpool[B, 64, 56, 56]
layer1[B, 256, 56, 56]
layer2[B, 512, 28, 28]
layer3[B, 1024, 14, 14]
layer4[B, 2048, 7, 7]
avgpool[B, 2048, 1, 1]
flatten[B, 2048]
fc[B, 1000]

各个 ResNet 构造函数

模型blocklayers
ResNet-18BasicBlock[2, 2, 2, 2]
ResNet-34BasicBlock[3, 4, 6, 3]
ResNet-50Bottleneck[3, 4, 6, 3]
ResNet-101Bottleneck[3, 4, 23, 3]
ResNet-152Bottleneck[3, 8, 36, 3]

修改分类头

如果要做自己的分类任务,比如 10 类分类:

python
import torch.nn as nn
from torchvision.models import resnet18, ResNet18_Weights

num_classes = 10

model = resnet18(weights=ResNet18_Weights.DEFAULT)
model.fc = nn.Linear(model.fc.in_features, num_classes)

这里:

python
model.fc.in_features

会自动拿到原来 fc 的输入维度。

对于 ResNet-18:

txt
model.fc.in_features = 512

对于 ResNet-50:

txt
model.fc.in_features = 2048