当前位置: 代码迷 >> 综合 >> pytorch model.load_state_dict报错
  详细解决方案

pytorch model.load_state_dict报错

热度:19   发布时间:2023-12-27 13:04:48.0

pytorch加载模型的时候如果模型里边使用了一些判断,判断作为选择执行条件,但是也保存到模型里面了,但是调用的时候不选择判断条件里边的网络并且使用load_state_dict,会报错,有些算子找不到名称。如:

if backbone == "mobilenet":self.backbone = mobilenet()flat_shape = 1024elif backbone == "inception_resnetv1":self.backbone = inception_resnet()
else:raise ValueError('Unsupported backbone - `{}`, Use mobilenet, inception_resnetv1.'.format(backbone))self.avg = nn.AdaptiveAvgPool2d((1,1))self.Bottleneck = nn.Linear(flat_shape, embedding_size,bias=False)self.last_bn = nn.BatchNorm1d(embedding_size, eps=0.001, momentum=0.1, affine=True)if mode == "train": # 判断条件,测试时,不加载全连接self.classifier = nn.Linear(embedding_size, num_classes)

可以加入strict=False选项,规避网络中没有调用的算子:

model2.load_state_dict(state_dict2, strict=False)

  相关解决方案