使用PyTorch实现CIFAR100数据集的迁移学习
发布时间:2023-12-29 13:01:31
迁移学习(Transfer learning)是一种将在一个任务上训练好的模型应用于另一个任务上的技术。在计算机视觉领域,迁移学习常用于处理较小的数据集,通过使用在大规模数据集上训练好的模型来提高性能。
CIFAR-100是一个包含100个类别的图像数据集,每个类别有600张图像。在本篇文章中,我们将使用PyTorch实现CIFAR-100数据集的迁移学习。
首先,我们需要安装PyTorch以及相应的依赖库。可以通过以下命令完成安装:
pip install torch torchvision
安装完成后,我们可以开始编写代码。
首先,我们需要导入所需的库:
import torch import torch.nn as nn import torchvision from torchvision import transforms
接下来,我们需要加载CIFAR-100数据集。PyTorch提供了torchvision库来方便地加载常用的计算机视觉数据集。
# 定义数据预处理的转换操作
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# 加载训练集
trainset = torchvision.datasets.CIFAR100(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32,
shuffle=True, num_workers=2)
# 加载测试集
testset = torchvision.datasets.CIFAR100(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=32,
shuffle=False, num_workers=2)
接下来,我们可以定义一个预训练模型并进行微调(fine-tuning)。在这个例子中,我们将使用ResNet-18作为预训练模型,并将其最后一层替换为一个新的全连接层。
# 加载预训练模型
model = torchvision.models.resnet18(pretrained=True)
# 冻结所有卷积层参数
for param in model.parameters():
param.requires_grad = False
# 替换最后一层全连接层
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 100) # 100个输出类别
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.fc.parameters(), lr=0.001, momentum=0.9)
然后,我们可以开始迁移学习的训练过程。
# 在训练集上进行训练
for epoch in range(5): # 训练5个epoch
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
if i % 200 == 199: # 每200个小批量打印一次损失值
print('[%d, %5d] loss: %.3f' %
(epoch + 1, i + 1, running_loss / 200))
running_loss = 0.0
print('Finished Training')
最后,我们可以在测试集上评估模型的性能。
correct = 0
total = 0
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (
100 * correct / total))
通过这个例子,我们实现了使用PyTorch进行CIFAR-100数据集的迁移学习。以上的代码示例可以帮助你理解和使用迁移学习技术,当处理小型数据集时,可以通过利用在大规模数据集上训练好的模型来提高性能。
