残差网络ResNet


一、定义

残差指的是什么?
其中ResNet提出了两种mapping:一种是identity mapping,指的就是图1中”弯弯的曲线”,另一种residual mapping,指的就是除了”弯弯的曲线“那部分,所以最后的输出是 y=F(x)+x
identity mapping顾名思义,就是指本身,也就是公式中的xx,而residual mapping指的是“差”,也就是y−x,所以残差指的就是F(x)部分。
ResNet结构:

img

它使用了一种连接方式叫做“shortcut connection”,顾名思义,shortcut就是“抄近道”的意思

为什么ResNet可以解决“随着网络加深,准确率不下降”的问题?
理论上,对于“随着网络加深,准确率下降”的问题,Resnet提供了两种选择方式,也就是identity mapping和residual mapping,如果网络已经到达最优,继续加深网络,residual mapping将被push为0,只剩下identity mapping,这样理论上网络一直处于最优状态了,网络的性能也就不会随着深度增加而降低了。

二、代码

img
import torch.nn as nn
import torch.nn.functional as F

class basic_block(nn.Module):
    '''定义了带实线部分的残差块'''
    def __init__(self,in_channels):
        super(basic_block, self).__init__()
        self.conv1 = nn.Conv2d(in_channels,in_channels,kernel_size=3,stride=1,padding=1)
        self.conv2 = nn.Conv2d(in_channels,in_channels,kernel_size=3,stride=1,padding=1)
    def forward(self, x):
        y = F.relu(self.conv1(x))
        y = self.conv2(y)
        return F.relu(x+y)

第二种是带有虚线标注的跳跃连接部分,第一种结构是在通道数不变的情况下,进行的残差结构运算,第二种的跳跃连接结构,通道数发生了改变,于是把它单独做成一个基础块,如下图:

img
class basic_block2(nn.Module):
    '''定义了带虚线部分的残差块'''
    def __init__(self,in_channels,out_channels):
        super(basic_block2, self).__init__()
        self.conv1 = nn.Conv2d(in_channels,out_channels,kernel_size=1,stride=2)
        self.conv2 = nn.Conv2d(in_channels,out_channels,kernel_size=3,stride=2,padding=1)
        self.conv3 = nn.Conv2d(out_channels,out_channels,kernel_size=3,stride=1,padding=1)
    def forward(self, x):
        z = self.conv1(x)
        y = F.relu(self.conv2(x))
        y = self.conv3(y)
        return F.relu(y+z)

文章作者: 乔彬
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 乔彬 !
  目录