当前位置: 代码迷 >> 综合 >> keras深度学习之猫狗分类三(特征提取)
  详细解决方案

keras深度学习之猫狗分类三(特征提取)

热度:98   发布时间:2023-11-24 18:59:52.0

想要将深度学习应用于小型图像数据集,一种常用且非常高效的方法是使用预训练网络。预训练网络(pretrained network)是一个保存好的网络,之前已在大型数据集(通常是大规模图像分类任务)上训练好。如果这个原始数据集足够大且足够通用,那么预训练网络学到的特征的空间层次结构可以有效地作为视觉世界的通用模型,因此这些特征可用于各种不同的计算机视觉问题,即使这些新问题涉及的类别和原始任务完全不同。举个例子,你在 ImageNet 上训练了一个网络(其类别主要是动物和日常用品),然后将这个训练好的网络应用于某个不相干的任务,比如在图像中识别家具。这种学到的特征在不同问题之间的可移植性,是深度学习与许多早期浅层学习方法相比的重要优势,它使得深度学习对小数据问题非常有效。
本例中,假设有一个在 ImageNet 数据集(140 万张标记图像,1000 个不同的类别)上训练好的大型卷积神经网络。ImageNet 中包含许多动物类别,其中包括不同种类的猫和狗,因此可以认为它在猫狗分类问题上也能有良好的表现。
使用预训练网络有两种方法:特征提取(feature extraction)和微调模型(fine-tuning)。两种方法我们都会介绍。首先来看特征提取。
特征提取是使用之前网络学到的表示来从新样本中提取出有趣的特征。然后将这些特征输入一个新的分类器,从头开始训练。
如前所述,用于图像分类的卷积神经网络包含两部分:首先是一系列池化层和卷积层,最后是一个密集连接分类器。第一部分叫作模型的卷积基(convolutional base)。对于卷积神经网络而言,特征提取就是取出之前训练好的网络的卷积基,在上面运行新数据,然后在输出上面训练一个新的分类器。
 保持卷积基不变,改变分类器
上图表示的为保持卷积基不变,改变分类器。
为什么仅重复使用卷积基?我们能否也重复使用密集连接分类器?一般来说,应该避免这么做。原因在于卷积基学到的表示可能更加通用,因此更适合重复使用。卷积神经网络的特征图表示通用概念在图像中是否存在,无论面对什么样的计算机视觉问题,这种特征图都可能很有用。但是,分类器学到的表示必然是针对于模型训练的类别,其中仅包含某个类别出现在整张图像中的概率信息。此外,密集连接层的表示不再包含物体在输入图像中的位置信息。密集连接层舍弃了空间的概念,而物体位置信息仍然由卷积特征图所描述。如果物体位置对于问题很重要,那么密集连接层的特征在很大程度上是无用的。
注意,某个卷积层提取的表示的通用性(以及可复用性)取决于该层在模型中的深度。模型中更靠近底部的层提取的是局部的、高度通用的特征图(比如视觉边缘、颜色和纹理),而更靠近顶部的层提取的是更加抽象的概念(比如“猫耳朵”或“狗眼睛”)。因此,如果你的新数据集与原始模型训练的数据集有很大差异,那么最好只使用模型的前几层来做特征提取,而不是使用整个卷积基。
本文中使用的卷积基为vgg16,我们来打印下vgg16的网络结构:

from tensorflow.keras.applications import VGG16''' weights 指定模型初始化的权重检查点。include_top 指定模型最后是否包含密集连接分类器。默认情况下,这个密集连接分 类器对应于 ImageNet 的 1000 个类别。因为我们打算使用自己的密集连接分类器(只有 两个类别:cat 和 dog),所以不需要包含它。input_shape 是输入到网络中的图像张量的形状。这个参数完全是可选的,如果不传 入这个参数,那么网络能够处理任意形状的输入。 '''
conv_base=VGG16(weights='imagenet',include_top=False,input_shape=(150,150,3))conv_base.summary()

网络结构为:

_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 150, 150, 3)]     0
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 150, 150, 64)      1792
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 150, 150, 64)      36928
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 75, 75, 64)        0
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 75, 75, 128)       73856
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 75, 75, 128)       147584
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 37, 37, 128)       0
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 37, 37, 256)       295168
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 37, 37, 256)       590080
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 37, 37, 256)       590080
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 18, 18, 256)       0
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 18, 18, 512)       1180160
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 18, 18, 512)       2359808
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 18, 18, 512)       2359808
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, 9, 9, 512)         0
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 9, 9, 512)         2359808
_________________________________________________________________
block5_conv2 (Conv2D)        (None, 9, 9, 512)         2359808
_________________________________________________________________
block5_conv3 (Conv2D)        (None, 9, 9, 512)         2359808
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 4, 4, 512)         0
=================================================================
Total params: 14,714,688
Trainable params: 14,714,688
Non-trainable params: 0

最后的特征图形状为 (4, 4, 512)。我们将在这个特征上添加一个密集连接分类器。
接下来,下一步有两种方法可供选择。
1、在你的数据集上运行卷积基,将输出保存成硬盘中的 Numpy 数组,然后用这个数据作为输入,输入到独立的密集连接分类器中(与本书第一部分介绍的分类器类似)。这种方法速度快,计算代价低,因为对于每个输入图像只需运行一次卷积基,而卷积基是目前流程中计算代价最高的。但出于同样的原因,这种方法不允许你使用数据增强。
2、在顶部添加 Dense 层来扩展已有模型(即 conv_base),并在输入数据上端到端地运行整个模型。这样你可以使用数据增强,因为每个输入图像进入模型时都会经过卷积基。但出于同样的原因,这种方法的计算代价比第一种要高很多。

1 不使用数据增强的快速特征提取

训练代码如下:

from cProfile import label
from statistics import mode
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from PIL import Image
from tensorflow.keras import layers
from tensorflow.keras import models
from tensorflow.keras import optimizers
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing import image as kimage
from tensorflow.keras.applications import VGG16
import numpy as npconv_base=VGG16(weights='imagenet',include_top=False,input_shape=(150,150,3))#训练样本的目录
train_dir='./dataset/training_set/'
#验证样本的目录
validation_dir='./dataset/validation_set/'
#测试样本目录
test_dir='./dataset/test_set/'datagen=ImageDataGenerator(rescale=1./255)
batch_size=20
#卷积基提取特征
def extract_features(dir,sample_count):features=np.zeros(shape=(sample_count,4,4,512))labels=np.zeros(shape=(sample_count,))generator=datagen.flow_from_directory(dir,target_size=(150,150),batch_size=batch_size,class_mode='binary')i=0for inputs_batch,lables_batch in generator:features_batch=conv_base.predict(inputs_batch)features[i*batch_size:(i+1)*batch_size]=features_batchlabels[i*batch_size:(i+1)*batch_size]=lables_batchi+=1if i*batch_size>=sample_count:breakreturn features,labelsif __name__=='__main__':#提取卷积特征train_features, train_labels = extract_features(train_dir, 3200) validation_features, validation_labels = extract_features(validation_dir, 800) test_features, test_labels = extract_features(test_dir, 1000)#将特征展平 以便传入全连接层train_features=train_features.reshape(3200,-1)validation_features=validation_features.reshape(800,-1)test_features=test_features.reshape(1000,-1)#构建训练网络model=models.Sequential()model.add(layers.Dense(units=256,activation='relu',input_dim=4*4*512))model.add(layers.Dropout(rate=0.25))model.add(layers.Dense(units=1,activation='sigmoid'))model.compile(optimizer=optimizers.RMSprop(lr=2e-5),loss='binary_crossentropy',metrics=['acc'])history = model.fit(train_features, train_labels,epochs=30,batch_size=20,validation_data=(validation_features, validation_labels))test_eval=model.evaluate(x=test_features,y=test_labels)print(test_eval)acc = history.history['acc']val_acc = history.history['val_acc']loss = history.history['loss']val_loss = history.history['val_loss']epochs = range(1, len(acc) + 1)plt.plot(epochs, acc, 'bo', label='Training acc')plt.plot(epochs, val_acc, 'b', label='Validation acc')plt.title('Training and validation accuracy')plt.legend()plt.figure()plt.plot(epochs, loss, 'bo', label='Training loss')plt.plot(epochs, val_loss, 'b', label='Validation loss')plt.title('Training and validation loss')plt.legend()plt.show()

准确率的变化曲线如下:
在这里插入图片描述
损失函数的变化曲线如下:
在这里插入图片描述
采用这种方法的验证集准确率可以达到90%。

2 采用数据增强的特征提取

训练代码如下:

from tensorflow.keras.preprocessing.image import ImageDataGenerator
from PIL import Image
from tensorflow.keras import layers
from tensorflow.keras import models
from tensorflow.keras import optimizers
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing import image as kimage
from tensorflow.keras.applications import VGG16
import numpy as np''' ImageDataGenerator可完成读取图像数据 读取图像文件 将jpeg图像解码为RGB像素网络 将这些像素转换到为浮点型张量并缩放到0~1之间 '''
#训练样本的目录
train_dir='./dataset/training_set/'
#验证样本的目录
validation_dir='./dataset/validation_set/'
#测试样本目录
test_dir='./dataset/test_set/'#训练样本生成器
#注意数据增强只能用于训练数据,不能用于验证数据和测试数据
''' 进行数据增强 '''
#设置数据增强
''' rotation_range 是角度值(在 0~180 范围内),表示图像随机旋转的角度范围。 width_shift 和 height_shift 是图像在水平或垂直方向上平移的范围(相对于总宽 度或总高度的比例)。 shear_range 是随机错切变换的角度。 zoom_range 是图像随机缩放的范围。 horizontal_flip 是随机将一半图像水平翻转。如果没有水平不对称的假设(比如真 实世界的图像),这种做法是有意义的。 fill_mode是用于填充新创建像素的方法,这些新像素可能来自于旋转或宽度/高度平移。 '''
train_datagen=ImageDataGenerator(rescale=1./255,rotation_range=40,width_shift_range=0.2,height_shift_range=0.2,shear_range=0.2,zoom_range=0.2,horizontal_flip=True,fill_mode='nearest')train_generator=train_datagen.flow_from_directory(directory=train_dir,target_size=(150,150),class_mode='binary',batch_size=20
)#验证样本生成器
validation_datagen=ImageDataGenerator(rescale=1./255)
validation_generator=train_datagen.flow_from_directory(directory=validation_dir,target_size=(150,150),class_mode='binary',batch_size=20
)#测试样本生成器
test_datagen=ImageDataGenerator(rescale=1./255)
test_generator=train_datagen.flow_from_directory(directory=test_dir,target_size=(150,150),class_mode='binary',batch_size=20
)if __name__=='__main__':conv_base=VGG16(weights='imagenet',include_top=False,input_shape=(150,150,3))#冻结卷积基 保证其权重在训练过程中不变conv_base.trainable=False#构建训练网络model=models.Sequential()model.add(conv_base)model.add(layers.Flatten())model.add(layers.Dense(units=256,activation='relu'))model.add(layers.Dense(units=1,activation='sigmoid'))model.compile(optimizer=optimizers.RMSprop(learning_rate=1e-4),loss='binary_crossentropy',metrics=['acc'])model.summary()#拟合模型history=model.fit_generator(train_generator,steps_per_epoch=100,epochs=100,validation_data=validation_generator,validation_steps=50)#测试测试集的准确率test_eval=model.evaluate_generator(test_generator)print(test_eval)acc = history.history['acc']val_acc = history.history['val_acc']loss = history.history['loss']val_loss = history.history['val_loss']epochs = range(1, len(acc) + 1)plt.plot(epochs, acc, 'bo', label='Training acc')plt.plot(epochs, val_acc, 'b', label='Validation acc')plt.title('Training and validation accuracy')plt.legend()plt.figure()plt.plot(epochs, loss, 'bo', label='Training loss')plt.plot(epochs, val_loss, 'b', label='Validation loss')plt.title('Training and validation loss')plt.legend()plt.show()

使用数据增强的准确率变化曲线:
在这里插入图片描述
使用数据增强的损失函数变化曲线:
在这里插入图片描述

  相关解决方案