ResNet18的网络架构图

首先将网络分为四层(layers),每层由两个模块组成:第一层由两个普通的残差块组成,其余三层各由一个下采样的残差块和一个普通的残差块组成。输入图像为3x224x224格式,经过7x7卷积后变为64x112x112,再经过最大池化后变为64x56x56,然后进入主网络架构。
代码如下:
import torch
from torch import nn
from torch.nn import functional as F


def _same_padding(kernel_size):
    """Return the padding that preserves spatial size at stride 1 for odd kernels."""
    if isinstance(kernel_size, int):
        return kernel_size // 2
    return tuple(k // 2 for k in kernel_size)


class BasicBlock(nn.Module):
    """Residual block with an identity shortcut.

    Computes ``relu(x + bn2(conv2(relu(bn1(conv1(x))))))``.  The input is
    added to the branch output unchanged, so the block is only usable when
    ``in_channels == out_channels`` and ``stride == 1``.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by both convolutions.
        kernel_size: Kernel size shared by both convolutions (odd, so that
            the derived padding keeps the spatial size).
        stride: Stride applied to both convolutions.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super(BasicBlock, self).__init__()
        # Derived instead of hard-coded to 1, which was only right for 3x3.
        padding = _same_padding(kernel_size)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=padding)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding=padding)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Return the activated sum of ``x`` and the two-convolution branch."""
        # Non-linearity between the two conv+BN stages; without it the branch
        # is two back-to-back linear stages.
        output = F.relu(self.bn1(self.conv1(x)))
        output = self.bn2(self.conv2(output))
        return F.relu(x + output)


class BasicDownBlock(nn.Module):
    """Residual block whose shortcut is a strided convolution.

    Used where the channel count changes and the feature map is downsampled,
    so the input cannot be added to the branch output directly.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels of the output feature map.
        kernel_size: Pair ``[main, shortcut]``; ``main`` is the kernel size of
            the two main-branch convolutions, ``shortcut`` that of the
            projection convolution.
        stride: Pair ``[first, second]``; ``first`` is used by the first
            main-branch convolution and by the shortcut, ``second`` by the
            second main-branch convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super(BasicDownBlock, self).__init__()
        main_padding = _same_padding(kernel_size[0])
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size[0], stride[0], padding=main_padding)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size[0], stride[1], padding=main_padding)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # Projection shortcut: same stride as conv1 so both branches end up
        # with the same spatial size (padding is 0 for the usual 1x1 kernel).
        self.conv3 = nn.Conv2d(
            in_channels, out_channels, kernel_size[1], stride[0],
            padding=_same_padding(kernel_size[1]),
        )
        self.bn3 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Return the activated sum of the main branch and the projected input."""
        output = F.relu(self.bn1(self.conv1(x)))
        output = self.bn2(self.conv2(output))
        output1 = self.bn3(self.conv3(x))
        return F.relu(output1 + output)


class ResNet18(nn.Module):
    """ResNet-18 style classifier: stem, four two-block layers, pooled linear head.

    The shape comments below are for a 3x224x224 input image.

    Args:
        num_classes: Size of the output logits vector (defaults to 10).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # 3x224x224 --> 64x112x112
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        # 64x112x112 --> 64x56x56
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # 64x56x56 --> 64x56x56
        self.layer1 = nn.Sequential(BasicBlock(64, 64, 3, 1), BasicBlock(64, 64, 3, 1))
        # 64x56x56 --> 128x28x28
        self.layer2 = nn.Sequential(BasicDownBlock(64, 128, [3, 1], [2, 1]), BasicBlock(128, 128, 3, 1))
        # 128x28x28 --> 256x14x14
        self.layer3 = nn.Sequential(BasicDownBlock(128, 256, [3, 1], [2, 1]), BasicBlock(256, 256, 3, 1))
        # 256x14x14 --> 512x7x7
        # The main kernel must be 3 like the other layers: a 7x7 kernel here
        # made the main branch collapse to 1x1 and get broadcast over the
        # 7x7 shortcut instead of being added element-wise.
        self.layer4 = nn.Sequential(BasicDownBlock(256, 512, [3, 1], [2, 1]), BasicBlock(512, 512, 3, 1))
        # 512x7x7 --> 512x1x1 (global average pooling, as the name says)
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.flat = nn.Flatten()
        self.linear = nn.Linear(512, num_classes)

    def forward(self, x):
        """Map a batch of images ``(N, 3, H, W)`` to logits ``(N, num_classes)``."""
        output = self.pool1(F.relu(self.bn1(self.conv1(x))))
        output = self.layer1(output)
        output = self.layer2(output)
        output = self.layer3(output)
        output = self.layer4(output)
        output = self.avgpool(output)
        output = self.flat(output)
        output = self.linear(output)
        return output


net = ResNet18()
# Smoke test: feed one random batch of 32 images (3x224x224) through the
# network and print the input and output shapes.
x = torch.randn((32, 3, 224, 224))
print(x.size())
y = net(x)
print(y.size())
代码中BasicBlock为普通的残差块(注意步长和卷积核的大小),BasicDownBlock为下采样的残差块;ResNet18将四层网络依次组合起来。最后进行验证:x.shape为torch.Size([32, 3, 224, 224]),y.shape为torch.Size([32, 10])。
















