欢迎访问宙启技术站
智能推送

使用Python实现的torchvision.modelsmobilenet_v2模型的图像分类结果可视化

发布时间:2023-12-12 08:29:41

使用Python实现的torchvision.models中的mobilenet_v2模型可以用于图像分类任务。Mobilenet_v2是一种轻量级的卷积神经网络,适用于在资源有限的设备上进行实时图像分类。下面我们将演示如何使用mobilenet_v2模型进行图像分类,并将分类结果进行可视化。

首先,我们需要导入必要的库和模块:

import torch
import torchvision
from torchvision import models, transforms
from PIL import Image
import urllib.request
import matplotlib.pyplot as plt

接下来,我们需要加载mobilenet_v2模型:

model = models.mobilenet_v2(pretrained=True)

这将加载一个在ImageNet数据集上预训练的mobilenet_v2模型。接下来,我们需要定义对图像进行预处理的转换函数:

preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

该转换函数将图像的大小调整为224x224像素,并对图像进行归一化处理。然后,我们可以定义一个函数来加载和预处理图像:

def load_image(url):
    with urllib.request.urlopen(url) as url_img:
        img = Image.open(url_img)
        img = preprocess(img).unsqueeze(0)
        return img

该函数将从指定的URL加载图像,并将其转换为张量。接下来,我们可以定义一个函数来进行图像分类和结果可视化:

def classify_image(image, model):
    labels = torchvision.datasets.ImageNet().classes
    output = model(image)
    _, predicted_idx = torch.max(output, 1)
    predicted_label = labels[predicted_idx.item()]
    predicted_prob = torch.nn.functional.softmax(output, dim=1)[0] * 100
    predicted_prob = predicted_prob.detach().numpy()
    
    image = image.squeeze(0).numpy().transpose((1, 2, 0))
    image = std * image + mean
    image = np.clip(image, 0, 1)

    plt.imshow(image)
    plt.axis('off')
    plt.title(f'Predicted label: {predicted_label} ({predicted_prob[predicted_idx.item()]:.2f}%)')
    plt.show()

该函数接收一个预处理的图像张量和模型作为输入,并输出预测的标签、标签的概率和可视化的图像。接下来,我们可以使用上面定义的函数进行图像分类和结果可视化:

url = 'https://example.com/image.jpg'
image = load_image(url)
classify_image(image, model)

该示例使用给定的URL加载图像,然后使用Mobilenet_v2模型进行图像分类,并可视化结果。