Deep Learning Day 20: Optimizer Comparison Experiment


Contents

  • Deep Learning Day 20: Optimizer Comparison Experiment
    • I. Foreword
    • II. My Environment
    • III. Preliminary Work
      • 1. Setting up the GPU
      • 2. Importing the Data
      • 3. Configuring the Dataset
      • 4. Data Visualization
    • IV. Building the Model
    • V. Training the Model
    • VI. Model Evaluation
      • 1. Accuracy and Loss Curves
      • 2. Evaluating the Model
    • VII. Final Thoughts

I. Foreword

  • 🍨 This article is a learning-log post from the 🔗 365-Day Deep Learning Training Camp
  • 🍦 Reference: 365-Day Deep Learning Training Camp, Week 11: Optimizer Comparison Experiment (readable by camp members only)
  • 🍖 Original author: K同学啊 | tutoring and custom projects available

In the previous data augmentation experiment we upgraded TensorFlow to 2.4.0. Some libraries may throw incompatibility errors as a result, so make sure your package versions match.

In this post we look at how different optimizers used in deep learning compare, using Adam and SGD as examples.

II. My Environment

  • OS: Windows 11
  • Language: Python 3.8.5
  • IDE: DataSpell 2022.2
  • Deep learning framework: TensorFlow 2.4.0
  • GPU: RTX 3070, 8 GB VRAM

III. Preliminary Work

1. Setting up the GPU

import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")

if gpus:
    gpu0 = gpus[0] # if there are multiple GPUs, use only the first one
    tf.config.experimental.set_memory_growth(gpu0, True) # allocate GPU memory on demand
    tf.config.set_visible_devices([gpu0],"GPU")

from tensorflow          import keras
import matplotlib.pyplot as plt
import pandas            as pd
import numpy             as np
import warnings,os,PIL,pathlib

warnings.filterwarnings("ignore")             # suppress warnings
plt.rcParams['font.sans-serif']    = ['SimHei']  # render Chinese labels correctly
plt.rcParams['axes.unicode_minus'] = False    # render minus signs correctly
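
Before training, it can be worth confirming which devices TensorFlow actually sees (an optional sanity check, not in the original post):

# Optional: list the physical and logical devices TensorFlow can use.
print("Physical GPUs:", tf.config.list_physical_devices("GPU"))
print("Logical GPUs: ", tf.config.list_logical_devices("GPU"))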

2. Importing the Data

This experiment uses the same dataset as the earlier Hollywood celebrity recognition project.

data_dir    = "/content/gdrive/MyDrive/data"
data_dir    = pathlib.Path(data_dir)
image_count = len(list(data_dir.glob('*/*')))
print("Total number of images:", image_count)
Total number of images: 1800
batch_size = 16
img_height = 336
img_width  = 336

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=12,
    image_size=(img_height, img_width),
    batch_size=batch_size)
Found 1800 files belonging to 17 classes.
Using 1440 files for training.
Found 1800 files belonging to 17 classes.
Using 360 files for validation.

Let's check the class labels of the dataset:

class_names = train_ds.class_names
print(class_names)
['Angelina Jolie', 'Brad Pitt', 'Denzel Washington', 'Hugh Jackman', 'Jennifer Lawrence', 'Johnny Depp', 'Kate Winslet', 'Leonardo DiCaprio', 'Megan Fox', 'Natalie Portman', 'Nicole Kidman', 'Robert Downey Jr', 'Sandra Bullock', 'Scarlett Johansson', 'Tom Cruise', 'Tom Hanks', 'Will Smith']
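
It can also help to peek at one batch to confirm the shapes and that the labels are integer class indices (a quick check, not part of the original post):

# Each element of the dataset is an (images, labels) pair.
for images, labels in train_ds.take(1):
    print(images.shape, labels.shape)   # expected: (16, 336, 336, 3) (16,)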

3. Configuring the Dataset

AUTOTUNE = tf.data.AUTOTUNE

def train_preprocessing(image,label):
    return (image/255.0,label)   # scale pixel values to [0, 1]

train_ds = (
    train_ds.cache()
    .shuffle(1000)
    .map(train_preprocessing)    # preprocessing function goes here
#     .batch(batch_size)           # batch_size was already set in image_dataset_from_directory
    .prefetch(buffer_size=AUTOTUNE)
)

val_ds = (
    val_ds.cache()
    .shuffle(1000)
    .map(train_preprocessing)    # preprocessing function goes here
#     .batch(batch_size)         # batch_size was already set in image_dataset_from_directory
    .prefetch(buffer_size=AUTOTUNE)
)

4. Data Visualization

plt.figure(figsize=(10, 8))  # figure is 10 wide by 8 tall
plt.suptitle("Data preview")

for images, labels in train_ds.take(1):
    for i in range(15):
        plt.subplot(4, 5, i + 1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)

        # show the image
        plt.imshow(images[i])
        # show the label (labels are already 0-indexed, so no offset is needed)
        plt.xlabel(class_names[labels[i]])

plt.show()

[Figure: 15 sample images from the training set, each labeled with its class name]

IV. Building the Model

from tensorflow.keras.layers import Dropout,Dense,BatchNormalization
from tensorflow.keras.models import Model

def create_model(optimizer='adam'):
    # load the pre-trained VGG16 backbone (ImageNet weights, no classifier head)
    vgg16_base_model = tf.keras.applications.vgg16.VGG16(weights='imagenet',
                                                         include_top=False,
                                                         input_shape=(img_width, img_height, 3),
                                                         pooling='avg')
    # freeze the backbone so only the new head is trained
    for layer in vgg16_base_model.layers:
        layer.trainable = False

    X = vgg16_base_model.output

    X = Dense(170, activation='relu')(X)
    X = BatchNormalization()(X)
    X = Dropout(0.5)(X)

    output = Dense(len(class_names), activation='softmax')(X)
    vgg16_model = Model(inputs=vgg16_base_model.input, outputs=output)

    vgg16_model.compile(optimizer=optimizer,
                        loss='sparse_categorical_crossentropy',
                        metrics=['accuracy'])
    return vgg16_model

model1 = create_model(optimizer=tf.keras.optimizers.Adam())
model2 = create_model(optimizer=tf.keras.optimizers.SGD())
model2.summary()
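
Note that create_model is written as a factory on purpose: each call builds and compiles a fresh network, so model1 and model2 share the same frozen VGG16 backbone and head architecture, and the only configured difference between them is the optimizer.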

The printed network architecture:

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5
58892288/58889256 [==============================] - 8s 0us/step
Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 336, 336, 3)]     0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 336, 336, 64)      1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 336, 336, 64)      36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 168, 168, 64)      0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 168, 168, 128)     73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 168, 168, 128)     147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 84, 84, 128)       0         
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 84, 84, 256)       295168    
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 84, 84, 256)       590080    
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 84, 84, 256)       590080    
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 42, 42, 256)       0         
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 42, 42, 512)       1180160   
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 42, 42, 512)       2359808   
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 42, 42, 512)       2359808   
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, 21, 21, 512)       0         
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 21, 21, 512)       2359808   
_________________________________________________________________
block5_conv2 (Conv2D)        (None, 21, 21, 512)       2359808   
_________________________________________________________________
block5_conv3 (Conv2D)        (None, 21, 21, 512)       2359808   
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 10, 10, 512)       0         
_________________________________________________________________
global_average_pooling2d_1 ( (None, 512)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 170)               87210     
_________________________________________________________________
batch_normalization_1 (Batch (None, 170)               680       
_________________________________________________________________
dropout_1 (Dropout)          (None, 170)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 17)                2907      
=================================================================
Total params: 14,805,485
Trainable params: 90,457
Non-trainable params: 14,715,028
_________________________________________________________________

Here the VGG16 weights are downloaded directly from the internet when the model is first created, and the download can fail, for example:

Exception: URL fetch failure on https://storage.googleapis.com/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5: None -- [WinError 10054] The remote host forcibly closed an existing connection.

This is a network problem. Retry a few times; if the download keeps failing, open the URL from the error message directly in a browser: https://storage.googleapis.com/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5. The file will download automatically; save it in the project folder, then change the weights argument in the VGG16 call above to the path of the downloaded file.
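
A minimal sketch of that change (assuming the .h5 file was saved in the project root under its original name):

vgg16_base_model = tf.keras.applications.vgg16.VGG16(
    weights="vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5",  # local file path instead of 'imagenet'
    include_top=False,
    input_shape=(img_width, img_height, 3),
    pooling='avg')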

Here we compare two optimizers, Adam and SGD. A brief introduction to each follows (with a configuration sketch after the list):

  • Adam

    keras.optimizers.Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)

    Adam dynamically adapts each parameter's learning rate using first- and second-moment estimates of the gradients. Its main advantage is that, after bias correction, every iteration's effective step size stays within a bounded range, which keeps the parameter updates stable. It also has modest memory requirements and works well for large datasets and high-dimensional models.

  • SGD

    keras.optimizers.SGD(lr=0.01, momentum=0.0, decay=0.0, nesterov=False)

    SGD is a stochastic gradient descent optimizer: at each iteration it computes the gradient on a mini-batch and updates the parameters with it. It is the most common optimization method.
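
The same create_model factory also accepts explicitly configured optimizer instances, which is how you would tune these hyperparameters. A minimal sketch (illustrative settings, not the values used in this run):

adam_opt = tf.keras.optimizers.Adam(learning_rate=1e-3)                              # Adam with its default step size written out
sgd_opt  = tf.keras.optimizers.SGD(learning_rate=1e-2, momentum=0.9, nesterov=True)  # SGD with Nesterov momentum
model_a  = create_model(optimizer=adam_opt)
model_b  = create_model(optimizer=sgd_opt)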

V. Training the Model

NO_EPOCHS = 50

history_model1  = model1.fit(train_ds, epochs=NO_EPOCHS, verbose=1, validation_data=val_ds)  # train with Adam
history_model2  = model2.fit(train_ds, epochs=NO_EPOCHS, verbose=1, validation_data=val_ds)  # train with SGD
Epoch 1/50
90/90 [==============================] - 113s 1s/step - loss: 2.8072 - accuracy: 0.1535 - val_loss: 2.7235 - val_accuracy: 0.0556
Epoch 2/50
90/90 [==============================] - 20s 221ms/step - loss: 2.0860 - accuracy: 0.3243 - val_loss: 2.4607 - val_accuracy: 0.2833
Epoch 3/50
90/90 [==============================] - 21s 238ms/step - loss: 1.8125 - accuracy: 0.4132 - val_loss: 2.2316 - val_accuracy: 0.2972
Epoch 4/50
90/90 [==============================] - 20s 224ms/step - loss: 1.5680 - accuracy: 0.5146 - val_loss: 1.9419 - val_accuracy: 0.4361
Epoch 5/50
90/90 [==============================] - 20s 225ms/step - loss: 1.4038 - accuracy: 0.5681 - val_loss: 1.6831 - val_accuracy: 0.4833
Epoch 6/50
90/90 [==============================] - 20s 224ms/step - loss: 1.2327 - accuracy: 0.6153 - val_loss: 1.6376 - val_accuracy: 0.4944
Epoch 7/50
90/90 [==============================] - 20s 223ms/step - loss: 1.1563 - accuracy: 0.6486 - val_loss: 1.6727 - val_accuracy: 0.4417
Epoch 8/50
90/90 [==============================] - 20s 224ms/step - loss: 1.0707 - accuracy: 0.6694 - val_loss: 1.4806 - val_accuracy: 0.5250
Epoch 9/50
90/90 [==============================] - 20s 224ms/step - loss: 0.9549 - accuracy: 0.7125 - val_loss: 1.6010 - val_accuracy: 0.4889
Epoch 10/50
90/90 [==============================] - 20s 224ms/step - loss: 0.8829 - accuracy: 0.7347 - val_loss: 1.7179 - val_accuracy: 0.4611
Epoch 11/50
90/90 [==============================] - 20s 223ms/step - loss: 0.8417 - accuracy: 0.7389 - val_loss: 1.7174 - val_accuracy: 0.4833
Epoch 12/50
90/90 [==============================] - 20s 225ms/step - loss: 0.7601 - accuracy: 0.7708 - val_loss: 1.5996 - val_accuracy: 0.4833
Epoch 13/50
90/90 [==============================] - 20s 224ms/step - loss: 0.7254 - accuracy: 0.7757 - val_loss: 1.6183 - val_accuracy: 0.5278
Epoch 14/50
90/90 [==============================] - 20s 224ms/step - loss: 0.6863 - accuracy: 0.8014 - val_loss: 1.7551 - val_accuracy: 0.4722
Epoch 15/50
90/90 [==============================] - 20s 224ms/step - loss: 0.6336 - accuracy: 0.8069 - val_loss: 1.8830 - val_accuracy: 0.4639
Epoch 16/50
90/90 [==============================] - 20s 224ms/step - loss: 0.5819 - accuracy: 0.8319 - val_loss: 1.4917 - val_accuracy: 0.5389
Epoch 17/50
90/90 [==============================] - 20s 224ms/step - loss: 0.5748 - accuracy: 0.8340 - val_loss: 1.8751 - val_accuracy: 0.4694
Epoch 18/50
90/90 [==============================] - 20s 223ms/step - loss: 0.5219 - accuracy: 0.8396 - val_loss: 2.0875 - val_accuracy: 0.4861
Epoch 19/50
90/90 [==============================] - 20s 224ms/step - loss: 0.4934 - accuracy: 0.8556 - val_loss: 1.9038 - val_accuracy: 0.5028
Epoch 20/50
90/90 [==============================] - 20s 224ms/step - loss: 0.4942 - accuracy: 0.8514 - val_loss: 1.6452 - val_accuracy: 0.5444
Epoch 21/50
90/90 [==============================] - 20s 224ms/step - loss: 0.4933 - accuracy: 0.8431 - val_loss: 2.1585 - val_accuracy: 0.4472
Epoch 22/50
90/90 [==============================] - 20s 225ms/step - loss: 0.4514 - accuracy: 0.8701 - val_loss: 2.0218 - val_accuracy: 0.4972
Epoch 23/50
90/90 [==============================] - 20s 223ms/step - loss: 0.4458 - accuracy: 0.8694 - val_loss: 1.6499 - val_accuracy: 0.5417
Epoch 24/50
90/90 [==============================] - 20s 224ms/step - loss: 0.3927 - accuracy: 0.8917 - val_loss: 2.3310 - val_accuracy: 0.4222
Epoch 25/50
90/90 [==============================] - 20s 224ms/step - loss: 0.3870 - accuracy: 0.8854 - val_loss: 1.6200 - val_accuracy: 0.5583
Epoch 26/50
90/90 [==============================] - 20s 224ms/step - loss: 0.3800 - accuracy: 0.8861 - val_loss: 1.9285 - val_accuracy: 0.5361
Epoch 27/50
90/90 [==============================] - 20s 224ms/step - loss: 0.3792 - accuracy: 0.8771 - val_loss: 2.3675 - val_accuracy: 0.4806
Epoch 28/50
90/90 [==============================] - 20s 224ms/step - loss: 0.3321 - accuracy: 0.8986 - val_loss: 1.7445 - val_accuracy: 0.5500
Epoch 29/50
90/90 [==============================] - 20s 224ms/step - loss: 0.3185 - accuracy: 0.9076 - val_loss: 1.7202 - val_accuracy: 0.5639
Epoch 30/50
90/90 [==============================] - 20s 224ms/step - loss: 0.3436 - accuracy: 0.8958 - val_loss: 1.6614 - val_accuracy: 0.5667
Epoch 31/50
90/90 [==============================] - 20s 224ms/step - loss: 0.2917 - accuracy: 0.9118 - val_loss: 2.0079 - val_accuracy: 0.5500
Epoch 32/50
90/90 [==============================] - 20s 224ms/step - loss: 0.3325 - accuracy: 0.8868 - val_loss: 2.0677 - val_accuracy: 0.5028
Epoch 33/50
90/90 [==============================] - 20s 224ms/step - loss: 0.2879 - accuracy: 0.9146 - val_loss: 1.6412 - val_accuracy: 0.6028
Epoch 34/50
90/90 [==============================] - 20s 224ms/step - loss: 0.2856 - accuracy: 0.9111 - val_loss: 2.1213 - val_accuracy: 0.5222
Epoch 35/50
90/90 [==============================] - 20s 224ms/step - loss: 0.2645 - accuracy: 0.9153 - val_loss: 2.0940 - val_accuracy: 0.5222
Epoch 36/50
90/90 [==============================] - 20s 224ms/step - loss: 0.2528 - accuracy: 0.9160 - val_loss: 1.8489 - val_accuracy: 0.5389
Epoch 37/50
90/90 [==============================] - 20s 224ms/step - loss: 0.2553 - accuracy: 0.9208 - val_loss: 1.8388 - val_accuracy: 0.5583
Epoch 38/50
90/90 [==============================] - 20s 224ms/step - loss: 0.2362 - accuracy: 0.9285 - val_loss: 1.8624 - val_accuracy: 0.5667
Epoch 39/50
90/90 [==============================] - 20s 223ms/step - loss: 0.2245 - accuracy: 0.9229 - val_loss: 1.9156 - val_accuracy: 0.5639
Epoch 40/50
90/90 [==============================] - 20s 224ms/step - loss: 0.2198 - accuracy: 0.9333 - val_loss: 2.2192 - val_accuracy: 0.5556
Epoch 41/50
90/90 [==============================] - 20s 224ms/step - loss: 0.2144 - accuracy: 0.9278 - val_loss: 1.8951 - val_accuracy: 0.5833
Epoch 42/50
90/90 [==============================] - 20s 224ms/step - loss: 0.2074 - accuracy: 0.9389 - val_loss: 2.0159 - val_accuracy: 0.5500
Epoch 43/50
90/90 [==============================] - 20s 225ms/step - loss: 0.2166 - accuracy: 0.9257 - val_loss: 2.2641 - val_accuracy: 0.5111
Epoch 44/50
90/90 [==============================] - 20s 224ms/step - loss: 0.2312 - accuracy: 0.9264 - val_loss: 2.0438 - val_accuracy: 0.5750
Epoch 45/50
90/90 [==============================] - 20s 223ms/step - loss: 0.2248 - accuracy: 0.9257 - val_loss: 2.2686 - val_accuracy: 0.5472
Epoch 46/50
90/90 [==============================] - 20s 224ms/step - loss: 0.2102 - accuracy: 0.9375 - val_loss: 2.2441 - val_accuracy: 0.5583
Epoch 47/50
90/90 [==============================] - 20s 224ms/step - loss: 0.2120 - accuracy: 0.9340 - val_loss: 2.3860 - val_accuracy: 0.5361
Epoch 48/50
90/90 [==============================] - 20s 224ms/step - loss: 0.1959 - accuracy: 0.9354 - val_loss: 2.4052 - val_accuracy: 0.5167
Epoch 49/50
90/90 [==============================] - 20s 224ms/step - loss: 0.1699 - accuracy: 0.9521 - val_loss: 2.5167 - val_accuracy: 0.5250
Epoch 50/50
90/90 [==============================] - 20s 224ms/step - loss: 0.1645 - accuracy: 0.9528 - val_loss: 2.1405 - val_accuracy: 0.5722
Epoch 1/50
90/90 [==============================] - 21s 226ms/step - loss: 3.0785 - accuracy: 0.0986 - val_loss: 2.7949 - val_accuracy: 0.1000
Epoch 2/50
90/90 [==============================] - 20s 223ms/step - loss: 2.5472 - accuracy: 0.1924 - val_loss: 2.6379 - val_accuracy: 0.1583
Epoch 3/50
90/90 [==============================] - 20s 225ms/step - loss: 2.2651 - accuracy: 0.2694 - val_loss: 2.4596 - val_accuracy: 0.2528
Epoch 4/50
90/90 [==============================] - 20s 224ms/step - loss: 2.0612 - accuracy: 0.3389 - val_loss: 2.2347 - val_accuracy: 0.3389
Epoch 5/50
90/90 [==============================] - 20s 224ms/step - loss: 1.9508 - accuracy: 0.3653 - val_loss: 2.0695 - val_accuracy: 0.3972
Epoch 6/50
90/90 [==============================] - 20s 224ms/step - loss: 1.8406 - accuracy: 0.4021 - val_loss: 1.9282 - val_accuracy: 0.3917
Epoch 7/50
90/90 [==============================] - 20s 224ms/step - loss: 1.7565 - accuracy: 0.4451 - val_loss: 1.8469 - val_accuracy: 0.4111
Epoch 8/50
90/90 [==============================] - 20s 223ms/step - loss: 1.6587 - accuracy: 0.4667 - val_loss: 1.7935 - val_accuracy: 0.4306
Epoch 9/50
90/90 [==============================] - 20s 224ms/step - loss: 1.5934 - accuracy: 0.4889 - val_loss: 1.6561 - val_accuracy: 0.4528
Epoch 10/50
90/90 [==============================] - 20s 223ms/step - loss: 1.5516 - accuracy: 0.4854 - val_loss: 1.7235 - val_accuracy: 0.3944
Epoch 11/50
90/90 [==============================] - 20s 224ms/step - loss: 1.4753 - accuracy: 0.5403 - val_loss: 1.6903 - val_accuracy: 0.4333
Epoch 12/50
90/90 [==============================] - 20s 223ms/step - loss: 1.4309 - accuracy: 0.5389 - val_loss: 1.6633 - val_accuracy: 0.4556
Epoch 13/50
90/90 [==============================] - 20s 225ms/step - loss: 1.4168 - accuracy: 0.5437 - val_loss: 1.6759 - val_accuracy: 0.4667
Epoch 14/50
90/90 [==============================] - 20s 223ms/step - loss: 1.3726 - accuracy: 0.5701 - val_loss: 1.7004 - val_accuracy: 0.4667
Epoch 15/50
90/90 [==============================] - 20s 224ms/step - loss: 1.2890 - accuracy: 0.5924 - val_loss: 1.6371 - val_accuracy: 0.4639
Epoch 16/50
90/90 [==============================] - 20s 223ms/step - loss: 1.2669 - accuracy: 0.6139 - val_loss: 1.5207 - val_accuracy: 0.4806
Epoch 17/50
90/90 [==============================] - 20s 223ms/step - loss: 1.2238 - accuracy: 0.6097 - val_loss: 1.5294 - val_accuracy: 0.4972
Epoch 18/50
90/90 [==============================] - 20s 224ms/step - loss: 1.1582 - accuracy: 0.6375 - val_loss: 1.4838 - val_accuracy: 0.5111
Epoch 19/50
90/90 [==============================] - 20s 223ms/step - loss: 1.1518 - accuracy: 0.6271 - val_loss: 1.5244 - val_accuracy: 0.5111
Epoch 20/50
90/90 [==============================] - 20s 224ms/step - loss: 1.1324 - accuracy: 0.6438 - val_loss: 1.5217 - val_accuracy: 0.4917
Epoch 21/50
90/90 [==============================] - 21s 237ms/step - loss: 1.0931 - accuracy: 0.6590 - val_loss: 1.4744 - val_accuracy: 0.5056
Epoch 22/50
90/90 [==============================] - 20s 224ms/step - loss: 1.0524 - accuracy: 0.6667 - val_loss: 1.4386 - val_accuracy: 0.5167
Epoch 23/50
90/90 [==============================] - 20s 224ms/step - loss: 1.0196 - accuracy: 0.6729 - val_loss: 1.4282 - val_accuracy: 0.5278
Epoch 24/50
90/90 [==============================] - 20s 224ms/step - loss: 1.0143 - accuracy: 0.6924 - val_loss: 1.5158 - val_accuracy: 0.5361
Epoch 25/50
90/90 [==============================] - 20s 224ms/step - loss: 0.9708 - accuracy: 0.6875 - val_loss: 1.5623 - val_accuracy: 0.4806
Epoch 26/50
90/90 [==============================] - 20s 223ms/step - loss: 0.9651 - accuracy: 0.6875 - val_loss: 1.3693 - val_accuracy: 0.5611
Epoch 27/50
90/90 [==============================] - 20s 223ms/step - loss: 0.9384 - accuracy: 0.7076 - val_loss: 1.4377 - val_accuracy: 0.5556
Epoch 28/50
90/90 [==============================] - 20s 224ms/step - loss: 0.8951 - accuracy: 0.7285 - val_loss: 1.4171 - val_accuracy: 0.5222
Epoch 29/50
90/90 [==============================] - 20s 224ms/step - loss: 0.8706 - accuracy: 0.7340 - val_loss: 1.6458 - val_accuracy: 0.5167
Epoch 30/50
90/90 [==============================] - 20s 224ms/step - loss: 0.8520 - accuracy: 0.7375 - val_loss: 1.4419 - val_accuracy: 0.5139
Epoch 31/50
90/90 [==============================] - 20s 223ms/step - loss: 0.8547 - accuracy: 0.7188 - val_loss: 1.2940 - val_accuracy: 0.5889
Epoch 32/50
90/90 [==============================] - 20s 223ms/step - loss: 0.8222 - accuracy: 0.7424 - val_loss: 1.4509 - val_accuracy: 0.5528
Epoch 33/50
90/90 [==============================] - 20s 223ms/step - loss: 0.8406 - accuracy: 0.7299 - val_loss: 1.4598 - val_accuracy: 0.5306
Epoch 34/50
90/90 [==============================] - 20s 225ms/step - loss: 0.7983 - accuracy: 0.7528 - val_loss: 1.5114 - val_accuracy: 0.5472
Epoch 35/50
90/90 [==============================] - 20s 224ms/step - loss: 0.7992 - accuracy: 0.7403 - val_loss: 1.4475 - val_accuracy: 0.5750
Epoch 36/50
90/90 [==============================] - 20s 224ms/step - loss: 0.7557 - accuracy: 0.7569 - val_loss: 1.5024 - val_accuracy: 0.5389
Epoch 37/50
90/90 [==============================] - 20s 224ms/step - loss: 0.7298 - accuracy: 0.7681 - val_loss: 1.4272 - val_accuracy: 0.5389
Epoch 38/50
90/90 [==============================] - 20s 223ms/step - loss: 0.7378 - accuracy: 0.7632 - val_loss: 1.3973 - val_accuracy: 0.5778
Epoch 39/50
90/90 [==============================] - 20s 224ms/step - loss: 0.7025 - accuracy: 0.7875 - val_loss: 1.3738 - val_accuracy: 0.5500
Epoch 40/50
90/90 [==============================] - 20s 224ms/step - loss: 0.6812 - accuracy: 0.7958 - val_loss: 1.5651 - val_accuracy: 0.5361
Epoch 41/50
90/90 [==============================] - 20s 224ms/step - loss: 0.6646 - accuracy: 0.7854 - val_loss: 1.4765 - val_accuracy: 0.5667
Epoch 42/50
90/90 [==============================] - 20s 224ms/step - loss: 0.6477 - accuracy: 0.8021 - val_loss: 1.5985 - val_accuracy: 0.5361
Epoch 43/50
90/90 [==============================] - 20s 224ms/step - loss: 0.6508 - accuracy: 0.8042 - val_loss: 1.3467 - val_accuracy: 0.5667
Epoch 44/50
90/90 [==============================] - 20s 225ms/step - loss: 0.6539 - accuracy: 0.7889 - val_loss: 1.3919 - val_accuracy: 0.5778
Epoch 45/50
90/90 [==============================] - 20s 224ms/step - loss: 0.6402 - accuracy: 0.8104 - val_loss: 1.3426 - val_accuracy: 0.5917
Epoch 46/50
90/90 [==============================] - 20s 224ms/step - loss: 0.6178 - accuracy: 0.8076 - val_loss: 1.4094 - val_accuracy: 0.5833
Epoch 47/50
90/90 [==============================] - 20s 223ms/step - loss: 0.6083 - accuracy: 0.8000 - val_loss: 1.3747 - val_accuracy: 0.5750
Epoch 48/50
90/90 [==============================] - 20s 224ms/step - loss: 0.6079 - accuracy: 0.8028 - val_loss: 1.5148 - val_accuracy: 0.5583
Epoch 49/50
90/90 [==============================] - 20s 224ms/step - loss: 0.6115 - accuracy: 0.8000 - val_loss: 1.9661 - val_accuracy: 0.4556
Epoch 50/50
90/90 [==============================] - 20s 224ms/step - loss: 0.5785 - accuracy: 0.8146 - val_loss: 1.4971 - val_accuracy: 0.5500
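
Comparing the two logs above: Adam fits the training set much faster (training accuracy around 0.95 after 50 epochs versus around 0.81 for SGD), while SGD ends with a noticeably lower validation loss (about 1.50 versus 2.14); final validation accuracy is close, roughly 0.57 for Adam and 0.55 for SGD.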

VI. Model Evaluation

1. Accuracy and Loss Curves

from matplotlib.ticker import MultipleLocator
plt.rcParams['savefig.dpi'] = 300 # DPI for saved figures
plt.rcParams['figure.dpi']  = 300 # DPI for displayed figures

acc1     = history_model1.history['accuracy']
acc2     = history_model2.history['accuracy']
val_acc1 = history_model1.history['val_accuracy']
val_acc2 = history_model2.history['val_accuracy']

loss1     = history_model1.history['loss']
loss2     = history_model2.history['loss']
val_loss1 = history_model1.history['val_loss']
val_loss2 = history_model2.history['val_loss']

epochs_range = range(len(acc1))

plt.figure(figsize=(16, 4))
plt.subplot(1, 2, 1)

plt.plot(epochs_range, acc1, label='Training Accuracy-Adam')
plt.plot(epochs_range, acc2, label='Training Accuracy-SGD')
plt.plot(epochs_range, val_acc1, label='Validation Accuracy-Adam')
plt.plot(epochs_range, val_acc2, label='Validation Accuracy-SGD')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
# set the x-axis tick interval to 1 epoch
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss1, label='Training Loss-Adam')
plt.plot(epochs_range, loss2, label='Training Loss-SGD')
plt.plot(epochs_range, val_loss1, label='Validation Loss-Adam')
plt.plot(epochs_range, val_loss2, label='Validation Loss-SGD')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')

# set the x-axis tick interval to 1 epoch
ax = plt.gca()
ax.xaxis.set_major_locator(MultipleLocator(1))

plt.show()

[Figure: training/validation accuracy (left) and loss (right) curves for Adam vs. SGD]

2. Evaluating the Model

def test_accuracy_report(model):
    score = model.evaluate(val_ds, verbose=0)
    print('Loss function: %s, accuracy:' % score[0], score[1])
    
test_accuracy_report(model2)
Loss function: 1.49705171585083, accuracy: 0.550000011920929
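
The same helper can be called on the Adam model for a side-by-side number (not part of the original run, so no recorded output is shown here):

test_accuracy_report(model1)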

VII. Final Thoughts

That wraps up this post. My computer has been acting up lately and can no longer train complex models, so I'm considering reinstalling the OS and cleaning it up. I ran this experiment on Google Colaboratory instead; so far the platform has been a pleasant experience, and the free compute it provides is enough for my needs. If your own machine isn't up to the task, it's worth a try.
