欢迎访问宙启技术站
智能推送

使用Python编写的torchvision.modelsmobilenet_v2模型在图像识别任务中的效果

发布时间:2023-12-12 08:29:16

MobileNetV2是一种轻量级的卷积神经网络模型,适用于移动设备和嵌入式设备等资源受限的场景。它具有较低的模型参数和计算量,同时在准确率和模型大小之间保持了一个良好的平衡。

下面是一个使用Python编写的torchvision.models.mobilenet_v2模型在图像识别任务中的示例代码:

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

# 设置随机种子
torch.manual_seed(0)

# 定义图像数据的转换
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

# 加载数据集
dataset = ImageFolder('path_to_dataset', transform=transform)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# 加载预训练的MobileNetV2模型
model = models.mobilenet_v2(pretrained=True)

# 替换最后一层全连接层
num_features = model.classifier[1].in_features
model.classifier[1] = nn.Linear(num_features, len(dataset.classes))

# 设置设备为GPU(如果可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# 训练模型
num_epochs = 10
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, (images, labels) in enumerate(dataloader):
        images = images.to(device)
        labels = labels.to(device)
        
        optimizer.zero_grad()
        
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
        if (i+1) % 10 == 0:
            print("Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}"
                  .format(epoch+1, num_epochs, i+1, len(dataloader), running_loss/10))
            running_loss = 0.0

# 保存模型
torch.save(model.state_dict(), 'path_to_save_model')

# 测试模型
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for images, labels in dataloader:
        images = images.to(device)
        labels = labels.to(device)
        
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print('Accuracy of the model on the test images: {}%'.format(accuracy))

在示例代码中,我们首先定义了模型训练所需的转换函数,如图像尺寸调整、归一化等。然后我们使用ImageFolder加载了数据集,并通过DataLoader将其以批量形式提供给模型训练。接下来,我们加载预训练的MobileNetV2模型,并替换最后一层全连接层以匹配数据集的类别数。然后,我们定义了损失函数和优化器,并使用SGD优化器来训练模型。在训练过程中,我们计算并输出训练损失的平均值。最后,我们保存训练好的模型,并在测试集上评估模型的准确率。

这个示例展示了如何使用MobileNetV2模型进行图像识别任务。你可以根据自己的数据集和需求进行相应的修改和调整,以达到更好的识别效果。