欢迎访问宙启技术站
智能推送

使用tensorflow.keras.callbacks中的ReduceLROnPlateau()函数自动调整学习率

发布时间:2023-12-18 10:02:00

在深度学习训练过程中,学习率的选择是十分重要的。学习率过大可能导致模型难以收敛,学习率过小又会导致训练速度缓慢。为了解决这个问题,TensorFlow提供了一个回调函数,通过观察验证集的表现来动态地调整学习率。这个回调函数叫做ReduceLROnPlateau()。

ReduceLROnPlateau()函数在验证损失不再改善时降低学习率。其定义如下:

class ReduceLROnPlateau(Callback):
    def __init__(self,
                 monitor='val_loss',
                 factor=0.1,
                 patience=10,
                 verbose=0,
                 mode='auto',
                 min_delta=1e-4,
                 cooldown=0,
                 min_lr=0,
                 **kwargs):
        super(ReduceLROnPlateau, self).__init__()

        self.monitor = monitor
        self.factor = factor
        self.patience = patience
        self.verbose = verbose
        self.mode = mode
        self.min_delta = min_delta
        self.cooldown = cooldown
        self.min_lr = min_lr
        self.wait = 0
        self.best = np.Inf
        self.monitor_op = None
        self._reset_lr()

参数说明:

- monitor:被监测的量,通常为验证损失或验证准确率。

- factor:学习率被降低的因数。新的学习率 = 学习率 * factor。

- patience:当验证损失不再改善时,经过多少个epoch才降低学习率。

- verbose:是否打印信息。

- mode:'auto','min'或'max'之一。在min模式下,如果监测值停止减小,则学习率将减小;在max模式下,如果监测值停止增大,则学习率将减小。

- min_delta:被视为提升的最小变化,小于该值时学习率不会被降低。

- cooldown:学习率降低后恢复正常操作之前等待的epoch数。

- min_lr:学习率的下边界。

下面我们使用一个例子来展示如何使用ReduceLROnPlateau()函数:

import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

# 加载数据
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 数据标准化
x_train, x_test = x_train / 255.0, x_test / 255.0

# 定义模型
model = Sequential()
model.add(Dense(128, activation='relu', input_shape=(784,)))
model.add(Dense(10, activation='softmax'))

# 编译模型
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# 定义学习率回调
lr_callback = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5)

# 训练模型
model.fit(x_train, y_train, epochs=10, validation_data=(x_test, y_test), callbacks=[lr_callback])

在这个例子中,我们使用了MNIST数据集进行手写数字识别。模型的优化器采用了Adam,损失函数采用了sparse_categorical_crossentropy。我们使用callbacks参数将lr_callback作为回调函数传入fit()函数中。

在每个epoch结束后,ReduceLROnPlateau()函数会判断验证损失是否有改善。如果连续5个epoch验证损失没有改善,学习率将以0.2的因数进行降低。