欢迎访问宙启技术站
智能推送

Keras中to_categorical()函数的使用方法及示例

发布时间:2023-12-17 09:35:27

在Keras中,to_categorical()函数用于将整数标签转换为分类矩阵。

使用方法:

to_categorical(y, num_classes=None, dtype='float32')

参数说明:

- y:整数数组或标签向量。

- num_classes:要生成的分类数。如果未指定,则将根据y中的最大值确定分类数。

- dtype:生成的矩阵的数据类型。

示例:

假设我们有一个包含3个类别的分类任务,并且有以下整数标签:

[0, 2, 1, 1, 0]

使用to_categorical()函数将整数标签转换为分类矩阵的代码如下:

from tensorflow.keras.utils import to_categorical

labels = [0, 2, 1, 1, 0]
categorical_labels = to_categorical(labels, num_classes=3)

print(categorical_labels)

输出结果为:

[[1. 0. 0.]
 [0. 0. 1.]
 [0. 1. 0.]
 [0. 1. 0.]
 [1. 0. 0.]]

在上面的例子中,原始的整数标签列表[0, 2, 1, 1, 0]被转换为一个3列的分类矩阵,每一列代表一个类别。1的位置表示对应的样本属于该类别,0的位置表示不属于该类别。

这个函数可以用于将整数标签转换为适合于训练分类模型的输入形式。例如,在一个多分类任务中,我们可以使用to_categorical()函数将整数标签转换为适合于输入到神经网络的独热编码形式的输入。