欢迎访问宙启技术站
智能推送

使用PyTorch实现CIFAR100数据集的迁移学习

发布时间:2023-12-29 13:01:31

迁移学习(Transfer learning)是一种将在一个任务上训练好的模型应用于另一个任务上的技术。在计算机视觉领域,迁移学习常用于处理较小的数据集,通过使用在大规模数据集上训练好的模型来提高性能。

CIFAR-100是一个包含100个类别的图像数据集,每个类别有600张图像。在本篇文章中,我们将使用PyTorch实现CIFAR-100数据集的迁移学习。

首先,我们需要安装PyTorch以及相应的依赖库。可以通过以下命令完成安装:

pip install torch torchvision

安装完成后,我们可以开始编写代码。

首先,我们需要导入所需的库:

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms

接下来,我们需要加载CIFAR-100数据集。PyTorch提供了torchvision库来方便地加载常用的计算机视觉数据集。

# 定义数据预处理的转换操作
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# 加载训练集
trainset = torchvision.datasets.CIFAR100(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32,
                                          shuffle=True, num_workers=2)

# 加载测试集
testset = torchvision.datasets.CIFAR100(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=32,
                                         shuffle=False, num_workers=2)

接下来,我们可以定义一个预训练模型并进行微调(fine-tuning)。在这个例子中,我们将使用ResNet-18作为预训练模型,并将其最后一层替换为一个新的全连接层。

# 加载预训练模型
model = torchvision.models.resnet18(pretrained=True)

# 冻结所有卷积层参数
for param in model.parameters():
    param.requires_grad = False
    
# 替换最后一层全连接层
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 100)  # 100个输出类别

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)

然后,我们可以开始迁移学习的训练过程。

# 在训练集上进行训练
for epoch in range(5):  # 训练5个epoch
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 200 == 199:  # 每200个小批量打印一次损失值
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 200))
            running_loss = 0.0

print('Finished Training')

最后,我们可以在测试集上评估模型的性能。

correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

通过这个例子,我们实现了使用PyTorch进行CIFAR-100数据集的迁移学习。以上的代码示例可以帮助你理解和使用迁移学习技术,当处理小型数据集时,可以通过利用在大规模数据集上训练好的模型来提高性能。