欢迎访问宙启技术站
智能推送

Keras中的to_categorical()函数的用法及示例

发布时间:2023-12-17 09:30:58

Keras中的to_categorical()函数用于将一维的标签数据转换为one-hot编码的形式。one-hot编码是一种常见的编码方式,它将每个类别表示为一个 的向量,其中只有对应类别的索引位置为1,其余位置都为0。

to_categorical()函数的语法如下:

keras.utils.to_categorical(y, num_classes=None)

参数说明:

- y: 一维的标签数据,可以是整数数组或类似数组的对象。

- num_classes: 整数,表示输出的one-hot编码的长度。

以下是to_categorical()函数的使用示例:

# 导入必要的库
import numpy as np
from keras.utils import to_categorical

# 生成一个一维的标签数据
y = np.array([0, 1, 2, 1, 3, 0])

# 将标签数据转换为one-hot编码形式
y_one_hot = to_categorical(y)

# 打印转换后的结果
print(y_one_hot)

输出结果:

array([[1., 0., 0., 0.],
       [0., 1., 0., 0.],
       [0., 0., 1., 0.],
       [0., 1., 0., 0.],
       [0., 0., 0., 1.],
       [1., 0., 0., 0.]], dtype=float32)

在上面的例子中,我们将一个一维的标签数据y [0, 1, 2, 1, 3, 0] 转换为了one-hot编码形式。一共有四个类别,所以最后生成了一个6x4的数组。每一行对应一个标签,每一列对应一个类别,只有对应类别的位置为1,其余位置都为0。

除了转换标签数据之外,to_categorical()函数还可以指定num_classes参数来指定输出的one-hot编码的长度。如果不指定num_classes参数,则默认为标签数据中的最大值加1。