欢迎访问宙启技术站
智能推送

Python中decode_predictions()函数的用法和实例

发布时间:2024-01-18 07:36:37

decode_predictions()函数是一个在Keras中由utils.py模块提供的方法,用于将模型的输出转换为易读的标签。

该函数主要用于将预测结果的向量编码转换为具体的标签名称。

使用该函数需要先导入decode_predictions方法:

from keras.applications.resnet50 import decode_predictions

decode_predictions()函数接受一个预测输出的数组作为参数,并将其转换为易于理解的列表。它返回一个具有以下结构的列表:

[
    ('标签1名称', 预测概率1),
    ('标签2名称', 预测概率2),
    ...
]

下面是一个使用decode_predictions()函数的示例:

from keras.applications.resnet50 import decode_predictions
from keras.applications.resnet50 import preprocess_input, ResNet50
from keras.preprocessing import image
import numpy as np

model = ResNet50(weights='imagenet')

img_path = 'elephant.jpg'
img = image.load_img(img_path, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)

preds = model.predict(x)
decoded_preds = decode_predictions(preds, top=3)

for i, (label, prob) in enumerate(decoded_preds[0]):
    print(f"第 {i+1} 个预测结果是:{label}, 概率为:{prob}")

这个例子中,我们使用了ResNet50模型加载了预训练的权重。然后我们加载了一张图片,并对其进行了预处理。接下来,我们使用模型对图片进行预测,并将预测结果传递给decode_predictions()函数。最后,我们遍历预测结果,并打印出每个预测结果及其对应的概率。

这是一个在ImageNet数据集上使用ResNet50模型进行图像分类的例子。decode_predictions()函数的作用是将预测结果的向量编码转换为易于理解的标签名称,使我们能够更好地理解模型的输出。