当前位置: 代码迷 >> 综合 >> image classification fashion mnist
  详细解决方案

image classification fashion mnist

热度:41   发布时间:2024-02-23 07:36:24.0
import tensorflow as tf
import matplotlib.pyplot as plt
from matplotlib.pyplot import MultipleLocator
# Written against tensorflow==2.1.0.
# Load the Fashion-MNIST dataset and rescale pixel values by 1/255 so the
# network sees floats (assumes 8-bit pixel values in 0..255).
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
x_train = x_train / 255.0
x_test = x_test / 255.0

class_nums = 10  # number of output classes (integer labels 0..9)
epochs = 30      # also read by the plotting code below

# Flatten each 28x28 image, then one ReLU hidden layer with dropout, then a
# softmax over the classes.
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))
model.add(tf.keras.layers.Dense(128, activation='relu'))
model.add(tf.keras.layers.Dropout(0.2))
model.add(tf.keras.layers.Dense(class_nums, activation='softmax'))

# Labels are plain integers, hence the "sparse" cross-entropy.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# NOTE(review): the test split is also used as validation data, so the
# validation curves are not an independent held-out estimate.
history = model.fit(
    x_train,
    y_train,
    epochs=epochs,
    validation_data=(x_test, y_test),
)
model.evaluate(x_test, y_test, verbose=2)
# Training curves are visualized below.
def _plot_history_panel(position, train_values, val_values, metric,
                        legend_loc, y_step, y_limits):
    """Draw one panel of a 1x2 grid: a training curve and its validation curve.

    Args:
        position: 1-based subplot index within the 1x2 grid.
        train_values: per-epoch values of the metric on the training data.
        val_values: per-epoch values of the same metric on validation data.
        metric: display name used in the legend labels, title and y-label.
        legend_loc: matplotlib legend location string.
        y_step: spacing between major y-axis ticks.
        y_limits: ``(bottom, top)`` for the y-axis, or ``None`` to let
            matplotlib autoscale.
    """
    # Derive the x-axis from the data itself rather than a global epoch
    # count, so any history length (e.g. after early stopping) plots cleanly.
    n_epochs = len(train_values)
    epochs_range = range(n_epochs)

    plt.subplot(1, 2, position)
    plt.plot(epochs_range, train_values, label='Training ' + metric)
    plt.plot(epochs_range, val_values, label='Validation ' + metric)
    plt.legend(loc=legend_loc)
    plt.title('Training and Validation ' + metric)
    plt.xlabel('Epoch', fontsize=14)
    plt.ylabel(metric, fontsize=14)
    plt.gca().yaxis.set_major_locator(MultipleLocator(y_step))
    plt.xlim(0, n_epochs)
    if y_limits is not None:
        plt.ylim(*y_limits)
    plt.grid()


def pltshow(loss, val_loss, accuracy, val_accuracy, *,
            accuracy_ylim=(0.8, 1), loss_ylim=(0, 0.6)):
    """Show training/validation accuracy (left) and loss (right) side by side.

    Args:
        loss: per-epoch training loss values.
        val_loss: per-epoch validation loss values.
        accuracy: per-epoch training accuracy values.
        val_accuracy: per-epoch validation accuracy values.
        accuracy_ylim: y-axis range of the accuracy panel, or ``None`` to
            autoscale. The default keeps the original fixed range, which
            hides any accuracy below 0.8.
        loss_ylim: y-axis range of the loss panel, or ``None`` to autoscale.
            The default keeps the original fixed range, which hides any loss
            above 0.6.
    """
    plt.figure(figsize=(8, 8))
    _plot_history_panel(1, accuracy, val_accuracy, 'Accuracy',
                        'lower right', 0.01, accuracy_ylim)
    _plot_history_panel(2, loss, val_loss, 'Loss',
                        'upper right', 0.05, loss_ylim)
    plt.show()
def history_show(history):
    """Plot the curves stored in a Keras ``History.history`` mapping.

    ``history`` must provide the keys 'loss', 'val_loss', 'accuracy' and
    'val_accuracy'; a missing key raises ``KeyError`` before anything is drawn.
    """
    pltshow(
        loss=history['loss'],
        val_loss=history['val_loss'],
        accuracy=history['accuracy'],
        val_accuracy=history['val_accuracy'],
    )


# Plot the curves recorded by the training run above.
history_show(history.history)
# 2020-09-27 guangjinzheng fashion mnist course

 

  相关解决方案