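Residual Networks
ResNet [2015]
Official PyTorch implementation: torchvision.models.resnet
Figure: a regular block (left) and a residual block (right)
Figure: comparison of network architectures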
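What problem does ResNet solve?
Before ResNet, a natural assumption was:
The deeper the network, the greater its expressive power, so the better it should perform.
Experiments show something counterintuitive, though: as a network gets deeper, its training error can actually go up.
Note that this is not overfitting. Overfitting typically looks like this:
- performance on the training set is good;
- performance on the test set gets worse.
The degradation problem that the ResNet paper is concerned with is different:
- the network is made deeper;
- the training error also gets worse;
- so the model fails to fit even the training data.
This shows that the problem is not just model capacity: deep networks are themselves hard to optimize.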
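ResNet's core idea is to add a "shortcut" to the network:
$$y = F(x) + x$$
where:
- $x$ is the input;
- $F(x)$ is the change that the convolutional layers actually have to learn;
- the $x$ path skips straight past those layers and is called the shortcut / skip connection;
- finally, the two paths are added together.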
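The formula above can be sketched in a few lines of PyTorch. This is a minimal illustration, not the torchvision implementation: `TinyResidual` is a hypothetical name, and zero-initializing the convolution is an assumption made here only so that $F(x) = 0$ and the block starts out as an exact identity mapping.

```python
import torch
import torch.nn as nn


class TinyResidual(nn.Module):
    """Hypothetical minimal residual block computing y = F(x) + x."""

    def __init__(self, channels):
        super().__init__()
        # F: a single 3x3 convolution that keeps the channel count and spatial size.
        self.f = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        # Assumption for this demo only: start with F(x) = 0.
        nn.init.zeros_(self.f.weight)

    def forward(self, x):
        return self.f(x) + x  # main path + shortcut


x = torch.randn(2, 8, 16, 16)
block = TinyResidual(8)
y = block(x)
is_identity = torch.allclose(y, x)
```

With $F$ at zero the block passes its input through unchanged. This is the intuition behind the shortcut: an added block that learns nothing should, in principle, leave the network no worse than its shallower counterpart.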
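TIP
A plain network asks its layers to learn the target mapping directly:
$$H(x)$$
ResNet does not make the convolutional layers learn $H(x)$ directly. Instead, they learn the residual:
$$F(x) = H(x) - x$$
The target mapping is then recovered through:
$$H(x) = F(x) + x$$
This is residual learning.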
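From the gradient perspective
Write the residual block as:
$$y = F(x) + x$$
If the loss function is $\mathcal{L}$, the chain rule gives:
$$\frac{\partial \mathcal{L}}{\partial x} = \frac{\partial \mathcal{L}}{\partial y}\left(\frac{\partial F(x)}{\partial x} + 1\right)$$
The $1$ here comes from the shortcut path.
Intuitively, this means:
The gradient does not have to pass through every convolutional layer; there is always at least one direct path along which it can flow back to earlier layers.
This is why residual connections ease the optimization difficulties of deep networks.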
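The single-block result extends to a stack of blocks. As a sketch, assume (an idealization, not what the torchvision code does) that no ReLU is applied after the addition, so that block $i$ computes $x_{i+1} = x_i + F_i(x_i)$. The indices $l < K$ and the symbols $x_i$, $F_i$ are notation introduced here for illustration.

```latex
x_K = x_l + \sum_{i=l}^{K-1} F_i(x_i)
\qquad\Longrightarrow\qquad
\frac{\partial \mathcal{L}}{\partial x_l}
  = \frac{\partial \mathcal{L}}{\partial x_K}
    \left( 1 + \frac{\partial}{\partial x_l} \sum_{i=l}^{K-1} F_i(x_i) \right)
```

Under that assumption, the additive $1$ lets the gradient at the deep feature $x_K$ reach $x_l$ directly, however small the terms inside the sum become.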
Code
resnet18 = ResNet(BasicBlock, [2, 2, 2, 2])
resnet34 = ResNet(BasicBlock, [3, 4, 6, 3])
resnet50 = ResNet(Bottleneck, [3, 4, 6, 3])
resnet101 = ResNet(Bottleneck, [3, 4, 23, 3])
resnet152 = ResNet(Bottleneck, [3, 8, 36, 3])
Helper convolution functions
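import torch.nn as nn


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
bias=False: the convolution is normally followed by BatchNorm2d. BN subtracts the per-channel mean and has its own learnable affine parameters, so the convolution bias can usually be dropped.
padding=dilation: for an ordinary convolution dilation=1, so padding=1, and with stride=1 the convolution does not change the feature-map size. For a dilated convolution, padding grows along with dilation.
groups: defaults to 1. Plain ResNet uses groups=1; ResNeXt uses a larger groups value.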
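Why padding=dilation preserves the size for a 3x3 kernel can be checked with the output-size formula given in the nn.Conv2d documentation. The helper name `conv_out_size` is hypothetical; it is a plain-Python sketch of that formula for one spatial dimension.

```python
def conv_out_size(size, kernel_size, stride, padding, dilation):
    """Output length of a convolution along one spatial dimension."""
    effective_kernel = dilation * (kernel_size - 1) + 1
    return (size + 2 * padding - effective_kernel) // stride + 1


# conv3x3 sets padding=dilation, so a stride-1 call keeps the size for any dilation.
same_size = [conv_out_size(56, 3, stride=1, padding=d, dilation=d) for d in (1, 2, 4)]

# With stride=2 the same padding rule halves the feature map.
halved = conv_out_size(56, 3, stride=2, padding=1, dilation=1)
```

For a 3x3 kernel the effective kernel spans 2 * dilation + 1 positions, so padding each side by dilation cancels exactly the positions lost at the border.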
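def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
It has two main uses:
- changing the number of channels inside Bottleneck;
- projecting the shortcut in the downsample branch so that its shape matches the output of the main branch.
A $1 \times 1$ convolution is mainly responsible for "changing channels"; sometimes it also changes the spatial size through stride.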
BasicBlock: the residual block of ResNet-18 / ResNet-34
Simplified version
class BasicBlock(nn.Module):
    expansion = 1  # output channels = planes * expansion

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()  # required before assigning submodules
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        # No ReLU right after the second convolution: add the shortcut first,
        # then apply a single ReLU to the sum.
        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out = out + identity
        out = self.relu(out)
        return out
Bottleneck: the residual block of ResNet-50 / 101 / 152
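Bottleneck first uses a $1 \times 1$ convolution to reduce the number of channels, then a $3 \times 3$ convolution to process the features at that reduced width, and finally a $1 \times 1$ convolution to expand the channels to planes * expansion.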
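A rough weight count suggests why this shape is attractive. The sketch below assumes a block whose input and output both have 256 channels (as in layer1 of ResNet-50) and compares a bottleneck of width 64 against two plain 3x3 convolutions at full width. Biases and BatchNorm parameters are ignored, and `conv_weights` is a hypothetical helper.

```python
def conv_weights(in_channels, out_channels, kernel_size):
    """Number of weights in a bias-free 2D convolution with groups=1."""
    return in_channels * out_channels * kernel_size * kernel_size


channels, width = 256, 64

# Bottleneck: 1x1 reduce -> 3x3 at reduced width -> 1x1 expand.
bottleneck = (
    conv_weights(channels, width, 1)
    + conv_weights(width, width, 3)
    + conv_weights(width, channels, 1)
)

# Two 3x3 convolutions applied directly at 256 channels.
two_3x3 = 2 * conv_weights(channels, channels, 3)

ratio = two_3x3 / bottleneck
```

Under these assumptions the bottleneck needs roughly 17 times fewer convolution weights, which helps keep the deeper models affordable.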
Simplified version:
class Bottleneck(nn.Module):
    expansion = 4  # output channels = planes * 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()  # required before assigning submodules
        width = planes  # standard ResNet: groups=1, base_width=64
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = nn.BatchNorm2d(width)
        self.conv2 = conv3x3(width, width, stride)
        self.bn2 = nn.BatchNorm2d(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out = out + identity
        out = self.relu(out)
        return out
WARNING
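In a standard ResNet, groups = 1 and base_width = 64, so width = planes.
ResNeXt and Wide ResNet, however, change groups or width_per_group (which is passed to the block as base_width), so the middle channel count width changes.
Bottleneck uses groups and base_width precisely so that it can accommodate plain ResNet, ResNeXt and Wide ResNet.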
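As a small worked example of how those two parameters change the middle width, the sketch below applies the width expression from torchvision's Bottleneck to planes = 64. The (groups, base_width) pairs are the values torchvision uses for resnet50, resnext50_32x4d and wide_resnet50_2; `bottleneck_width` is a hypothetical helper name.

```python
def bottleneck_width(planes, base_width=64, groups=1):
    """Middle channel count of a Bottleneck block (torchvision's expression)."""
    return int(planes * (base_width / 64.0)) * groups


planes = 64  # base channel count of layer1

widths = {
    # Plain ResNet: defaults, so width == planes.
    "resnet50": bottleneck_width(planes),
    # ResNeXt-50 32x4d: 32 groups of 4 channels each.
    "resnext50_32x4d": bottleneck_width(planes, base_width=4, groups=32),
    # Wide ResNet-50-2: width_per_group doubled to 128.
    "wide_resnet50_2": bottleneck_width(planes, base_width=128),
}
```

Both variants end up with a middle width of 128 in layer1, but ResNeXt gets there by splitting the 3x3 convolution into groups, while Wide ResNet simply doubles the channel count.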
width = int(planes * (base_width / 64.0)) * groups
ResNet V1.5
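Here:
self.conv1 = conv1x1(inplanes, width)
self.conv2 = conv3x3(width, width, stride)
self.conv3 = conv1x1(width, planes * self.expansion)
TorchVision's Bottleneck places the downsampling stride on the second convolution, the $3 \times 3$ conv2, whereas the original paper's implementation places it on the first $1 \times 1$ convolution (conv1). This variant is known as ResNet V1.5.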
Initialization
# Channel count entering layer1; _make_layer updates it after each stage.
self.inplanes = 64
# norm_layer is the normalization layer class, nn.BatchNorm2d by default.
# [stem] conv1 -> bn1 -> relu -> maxpool
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = norm_layer(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
# [four residual stages] layer1 -> layer2 -> layer3 -> layer4
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
# [classification head] avgpool -> flatten -> fc
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
layer1, layer2, layer3 and layer4 are not single layers but four stages. Each stage contains several residual blocks.
# ResNet-18
# layer1 has 2 BasicBlocks
# layer2 has 2 BasicBlocks
# layer3 has 2 BasicBlocks
# layer4 has 2 BasicBlocks
layers = [2, 2, 2, 2]
# ResNet-50
# layer1 has 3 Bottlenecks
# layer2 has 4 Bottlenecks
# layer3 has 6 Bottlenecks
# layer4 has 3 Bottlenecks
layers = [3, 4, 6, 3]
_make_layer
It builds one stage, that is: one block that may change the spatial size / channel count, followed by blocks - 1 blocks that keep both unchanged.
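Which stages need a projection on the shortcut of their first block follows from the condition `stride != 1 or inplanes != planes * expansion` that _make_layer checks. The plain-Python sketch below replays only that bookkeeping, without building any layers; `needs_downsample` is a hypothetical helper.

```python
def needs_downsample(expansion):
    """For each stage, report whether its first block needs a projection shortcut."""
    inplanes = 64  # channels leaving the stem
    stages = [("layer1", 64, 1), ("layer2", 128, 2), ("layer3", 256, 2), ("layer4", 512, 2)]
    result = {}
    for name, planes, stride in stages:
        out_channels = planes * expansion
        # Same test as in _make_layer: the spatial size or the channel count changes.
        result[name] = stride != 1 or inplanes != out_channels
        inplanes = out_channels  # the remaining blocks of the stage keep this shape
    return result


basic = needs_downsample(expansion=1)       # BasicBlock: ResNet-18 / 34
bottleneck = needs_downsample(expansion=4)  # Bottleneck: ResNet-50 / 101 / 152
```

With BasicBlock, layer1 keeps a pure identity shortcut. With Bottleneck, even layer1 needs a 1x1 projection, because the channel count grows from 64 to 256 although the spatial size stays the same.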
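Simplified version:
# block: whether to use BasicBlock or Bottleneck
# planes: base channel count of this stage
# blocks: how many blocks to stack in this stage
# stride: whether the first block downsamples spatially
# dilate: whether to replace stride downsampling with dilation (unused in this simplified version)
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
    norm_layer = nn.BatchNorm2d  # simplified; torchvision reads this from self._norm_layer
    downsample = None

    # downsample makes the shape of the shortcut branch match the output of the main branch.
    if stride != 1 or self.inplanes != planes * block.expansion:
        downsample = nn.Sequential(
            conv1x1(self.inplanes, planes * block.expansion, stride),
            norm_layer(planes * block.expansion),
        )

    layers = []
    # First block: may change the spatial size and the channel count.
    layers.append(
        block(
            self.inplanes,
            planes,
            stride,
            downsample,
        )
    )
    self.inplanes = planes * block.expansion

    # Remaining blocks: input and output shapes are identical.
    for _ in range(1, blocks):
        layers.append(
            block(
                self.inplanes,
                planes,
            )
        )
    return nn.Sequential(*layers)
Forward pass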
def _forward_impl(self, x):
    # stem
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.maxpool(x)

    # four residual stages
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)

    # classification head
    x = self.avgpool(x)
    x = torch.flatten(x, 1)
    x = self.fc(x)
    return x
Shape changes in ResNet-18
Input:
[B, 3, 224, 224]
ResNet-18:
| Stage | Output shape |
|---|---|
| input | [B, 3, 224, 224] |
| conv1 | [B, 64, 112, 112] |
| maxpool | [B, 64, 56, 56] |
| layer1 | [B, 64, 56, 56] |
| layer2 | [B, 128, 28, 28] |
| layer3 | [B, 256, 14, 14] |
| layer4 | [B, 512, 7, 7] |
| avgpool | [B, 512, 1, 1] |
| flatten | [B, 512] |
| fc | [B, 1000] |
Shape changes in ResNet-50
Input:
[B, 3, 224, 224]
ResNet-50:
| Stage | Output shape |
|---|---|
| input | [B, 3, 224, 224] |
| conv1 | [B, 64, 112, 112] |
| maxpool | [B, 64, 56, 56] |
| layer1 | [B, 256, 56, 56] |
| layer2 | [B, 512, 28, 28] |
| layer3 | [B, 1024, 14, 14] |
| layer4 | [B, 2048, 7, 7] |
| avgpool | [B, 2048, 1, 1] |
| flatten | [B, 2048] |
| fc | [B, 1000] |
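Both shape tables can be reproduced with a little arithmetic. The sketch below is plain Python rather than a real forward pass: `out_size` and `trace_shapes` are hypothetical helpers, and the sketch relies on the hyperparameters shown earlier (a 7x7 stride-2 convolution with padding 3, a 3x3 stride-2 max pooling with padding 1, and stride 2 in the first block of layer2 to layer4).

```python
def out_size(size, kernel_size, stride, padding):
    """Output length of a conv / pooling window along one spatial dimension."""
    return (size + 2 * padding - kernel_size) // stride + 1


def trace_shapes(expansion, size=224, num_classes=1000):
    """Return (stage, channels, spatial size) tuples for a ResNet on a square input."""
    shapes = [("input", 3, size)]
    size = out_size(size, 7, 2, 3)
    shapes.append(("conv1", 64, size))
    size = out_size(size, 3, 2, 1)
    shapes.append(("maxpool", 64, size))
    stages = [("layer1", 64, 1), ("layer2", 128, 2), ("layer3", 256, 2), ("layer4", 512, 2)]
    for name, planes, stride in stages:
        # Only the strided 3x3 convolution in the first block changes the size.
        size = out_size(size, 3, stride, 1)
        shapes.append((name, planes * expansion, size))
    shapes.append(("avgpool", 512 * expansion, 1))
    shapes.append(("fc", num_classes, None))  # a flat feature vector, no spatial size
    return shapes


resnet18_shapes = trace_shapes(expansion=1)  # BasicBlock
resnet50_shapes = trace_shapes(expansion=4)  # Bottleneck
```

The spatial sizes are identical for both models; only the channel counts differ, by the factor block.expansion.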
Constructors for each ResNet
| Model | block | layers |
|---|---|---|
| ResNet-18 | BasicBlock | [2, 2, 2, 2] |
| ResNet-34 | BasicBlock | [3, 4, 6, 3] |
| ResNet-50 | Bottleneck | [3, 4, 6, 3] |
| ResNet-101 | Bottleneck | [3, 4, 23, 3] |
| ResNet-152 | Bottleneck | [3, 8, 36, 3] |
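The numbers in the model names can be recovered from this table. As a sketch, count the weighted layers on the main path: 2 convolutions per BasicBlock or 3 per Bottleneck, plus the stem convolution and the final fc layer. The 1x1 projections in the downsample branches are not counted under this convention, and `count_depth` is a hypothetical helper.

```python
def count_depth(convs_per_block, layers):
    """Weighted layers on the main path: stem conv + block convs + fc."""
    return 1 + convs_per_block * sum(layers) + 1


configs = {
    "resnet18": (2, [2, 2, 2, 2]),    # BasicBlock has 2 convolutions
    "resnet34": (2, [3, 4, 6, 3]),
    "resnet50": (3, [3, 4, 6, 3]),    # Bottleneck has 3 convolutions
    "resnet101": (3, [3, 4, 23, 3]),
    "resnet152": (3, [3, 8, 36, 3]),
}

depths = {name: count_depth(convs, layers) for name, (convs, layers) in configs.items()}
```

ResNet-34 and ResNet-50 share the same layers list; the difference in depth comes entirely from swapping BasicBlock for Bottleneck.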
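Modifying the classification head
To use the model for your own classification task, for example with 10 classes:
import torch.nn as nn
from torchvision.models import resnet18, ResNet18_Weights

num_classes = 10
model = resnet18(weights=ResNet18_Weights.DEFAULT)
model.fc = nn.Linear(model.fc.in_features, num_classes)
Here:
model.fc.in_features
automatically picks up the input dimension of the original fc layer. The new fc layer is randomly initialized, so it still has to be trained on the new task.
For ResNet-18:
model.fc.in_features == 512
For ResNet-50:
model.fc.in_features == 2048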