当前位置: 代码迷 >> 综合 >> MNIST图像分类 - Keras
  详细解决方案

MNIST图像分类 - Keras

热度:131   发布时间:2023-09-22 06:02:07.0

1. 序贯模型的网络拓扑

实例中的小模型,由输入层、卷积层、池化层、扁平层和稠密层,以及输出层等网络层构成。
MNIST图像分类 - Keras

2. 代码实现

图像分类的网络拓扑相对简单,准备数据集、构建序贯模型、编译模型、训练模型,并且评估模型。

# encoding=utf-8 **import keras
from keras.utils import plot_model
from keras.models import Model
from keras.layers import Input, Reshape, Dense, Flatten
from keras.layers.convolutional import Conv2D
from keras.layers.pooling import MaxPooling2D
from keras.datasets import mnist# 设置参数
num_classes = 10
batch_size = 32
epochs = 30# 图像维度
img_rows, img_cols = 28, 28# 加载数据集
(x_train, y_train), (x_test, y_test) = mnist.load_data()
y_train = keras.utils.to_categorical(y_train, num_classes=num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes=num_classes)# 设置输入数据维度
input_shape = (img_rows, img_cols)
inputs = Input(input_shape)
print(input_shape + (1, ))# 构建网络模型
x = Reshape(input_shape + (1, ), input_shape=input_shape)(inputs)
conv1 = Conv2D(14, kernel_size=3, activation='relu')(x)
pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
conv2 = Conv2D(7, kernel_size=3, activation='relu')(pool1)
pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
flatten = Flatten()(pool2)
output = Dense(10, activation='sigmoid')(flatten)
model = Model(inputs=inputs, outputs=output)# 输出网络图
print(model.summary())# 绘制网络图
plot_model(model, to_file='convolutional_neural_network.png', show_shapes=True, show_layer_names=True)# 优化函数
opt = keras.optimizers.RMSprop(lr=1e-4, decay=1e-6)# 编译模型
model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])# 训练模型
history = model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs,validation_data=(x_test, y_test), shuffle=True)# 评估模型
scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])# 绘制准确率曲线
plt.plot(history.history['acc'])
plt.plot(history.history['val_acc'])
plt.title('Model Accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()# 绘制损失率曲线
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model Loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['train', 'test'], loc='upper left')
plt.show()

3. 训练结果

评估一个模型的好坏,可以从精确率、召回率、损失率等方面评价。

Test loss: 0.06894081085920334
Test accuracy: 0.9801999926567078

准确率MNIST图像分类 - Keras

损失率MNIST图像分类 - Keras

  相关解决方案