欢迎访问宙启技术站
智能推送

Python中如何使用CSVLogger()记录数据

发布时间:2023-12-23 22:30:31

在Python中,可以使用CSVLogger()来记录数据。CSVLogger()keras库中的一个函数,用于将训练过程中的信息保存到CSV文件中,如训练损失和准确率等。以下是使用CSVLogger()的示例:

首先,导入所需的库和模块:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.callbacks import CSVLogger

然后,准备训练数据和模型:

(X_train, y_train), (X_test, y_test) = keras.datasets.mnist.load_data()
X_train = X_train / 255.0
X_test = X_test / 255.0

model = keras.models.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

接下来,创建一个CSVLogger对象,指定要保存日志的文件路径:

csv_logger = CSVLogger('training.log')

然后,使用fit()函数来训练模型,并将CSVLogger对象作为callbacks参数传入:

model.fit(X_train, y_train, epochs=10, validation_data=(X_test, y_test), callbacks=[csv_logger])

训练过程中,日志信息将会保存到training.log文件中。

可以使用pandas库读取和查看保存的日志文件:

import pandas as pd
df = pd.read_csv('training.log')
print(df.head())

输出结果:

   epoch  accuracy      loss  val_accuracy  val_loss
0      0  0.897067  0.321148        0.9617  0.138226
1      1  0.964333  0.124198        0.9700  0.093779
2      2  0.976217  0.081814        0.9722  0.085722
3      3  0.982233  0.058765        0.9732  0.082745
4      4  0.985600  0.045758        0.9726  0.089980

通过以上使用CSVLogger()的示例,可以方便地记录模型的训练过程信息,并以CSV格式保存到文件中。