当前位置: 代码迷 >> 综合 >> Pytorch学习(三)--用50行代码搭建ResNet
  详细解决方案

Pytorch学习(三)--用50行代码搭建ResNet

热度:74   发布时间:2023-09-24 05:04:15.0
#------------------------------用50行代码搭建ResNet-------------------------------------------
from torch import nn
import torch as t
from torch.nn import functional as Fclass ResidualBlock(nn.Module):#实现子module: Residual    Blockdef __init__(self,inchannel,outchannel,stride=1,shortcut=None):super(ResidualBlock,self).__init__()self.left=nn.Sequential(nn.Conv2d(inchannel,outchannel,3,stride,1,bias=False),nn.BatchNorm2d(outchannel),nn.ReLU(inplace=True),nn.Conv2d(outchannel,outchannel,3,1,1,bias=False),nn.BatchNorm2d(outchannel))self.right=shortcutdef forward(self,x):out=self.left(x)residual=x if self.right is None else self.right(x)out+=residualreturn F.relu(out)class ResNet(nn.Module):#实现主module:ResNet34#ResNet34包含多个layer,每个layer又包含多个residual block#用子module实现residual block , 用 _make_layer 函数实现layerdef __init__(self,num_classes=1000):super(ResNet,self).__init__()self.pre=nn.Sequential(nn.Conv2d(3,64,7,2,3,bias=False),nn.BatchNorm2d(64),nn.ReLU(inplace=True),nn.MaxPool2d(3,2,1))#重复的layer,分别有3,4,6,3个residual blockself.layer1=self._make_layer(64,64,3)self.layer2=self._make_layer(64,128,4,stride=2)self.layer3=self._make_layer(128,256,6,stride=2)self.layer4=self._make_layer(256,512,3,stride=2)#分类用的全连接self.fc=nn.Linear(512,num_classes)def _make_layer(self,inchannel,outchannel,block_num,stride=1):#构建layer,包含多个residual blockshortcut=nn.Sequential(nn.Conv2d(inchannel,outchannel,1,stride,bias=False),nn.BatchNorm2d(outchannel))layers=[ ]layers.append(ResidualBlock(inchannel,outchannel,stride,shortcut))for i in range(1,block_num):layers.append(ResidualBlock(outchannel,outchannel))return nn.Sequential(*layers)def forward(self,x):x=self.pre(x)x=self.layer1(x)x=self.layer2(x)x=self.layer3(x)x=self.layer4(x)x=F.avg_pool2d(x,7)x=x.view(x.size(0),-1)return self.fc(x)
model=ResNet()
input=t.autograd.Variable(t.randn(1,3,224,224))
o=model(input)
print(o)

 

大致框架算是理解了,但是细节部分比如卷积层的输入输出的大小之类的,还需要仔细研究。

 

Pytorch学习系列(一)至(四)均摘自《深度学习框架PyTorch入门与实践》陈云

  相关解决方案