欢迎访问宙启技术站
智能推送

用Python构建CIFARNet神经网络进行图像分类

发布时间:2024-01-06 15:47:22

CIFARNet是一个用于CIFAR-10图像分类任务的卷积神经网络架构。CIFAR-10数据集是一个经典的计算机视觉数据集,包含10个不同的图像类别,每个类别有6000个32x32彩色图像。在本文中,我们将使用Python构建CIFARNet网络,并给出一个简单的图像分类的例子。

首先,我们需要导入所需的库,包括PyTorch和torchvision。PyTorch是一个使用CUDA加速的开源深度学习框架,而torchvision则提供了一些用于计算机视觉的工具和数据集。

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

接下来,定义CIFARNet网络的结构。CIFARNet包含了一些卷积层、池化层和全连接层。在这个例子中,我们使用了两个卷积层和两个全连接层。

class CIFARNet(nn.Module):
    def __init__(self):
        super(CIFARNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)
    
    def forward(self, x):
        x = self.pool(self.relu1(self.conv1(x)))
        x = self.pool(self.relu2(self.conv2(x)))
        x = x.view(-1, 64 * 8 * 8)
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return x

接下来,我们准备CIFAR-10数据集。首先,定义数据集的路径和转换操作。

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                             transform=transforms.ToTensor(), download=True)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                            transform=transforms.ToTensor())

然后,创建数据加载器,用于批量加载训练和测试数据。

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64,
                                           shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64,
                                          shuffle=False, num_workers=2)

接下来,初始化CIFARNet网络和优化器。

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = CIFARNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

然后,定义训练函数和测试函数。

def train(net, device, train_loader, optimizer, criterion):
    net.train()
    running_loss = 0
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(train_loader)

def test(net, device, test_loader, criterion):
    net.eval()
    test_loss, correct, total = 0, 0, 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    return test_loss / len(test_loader), correct / total

接下来,我们可以开始训练和测试CIFARNet网络。

for epoch in range(10):
    train_loss = train(net, device, train_loader, optimizer, criterion)
    test_loss, accuracy = test(net, device, test_loader, criterion)
    print("Epoch: {}, Train Loss: {:.4f}, Test Loss: {:.4f}, Accuracy: {:.2f}%".format(
        epoch, train_loss, test_loss, accuracy * 100))

在每个训练周期之后,我们将输出训练损失、测试损失和准确率。训练和测试的结果将会如下所示:

Epoch: 0, Train Loss: 2.2567, Test Loss: 2.0814, Accuracy: 21.24%
Epoch: 1, Train Loss: 1.9565, Test Loss: 1.7858, Accuracy: 33.57%
Epoch: 2, Train Loss: 1.7289, Test Loss: 1.6218, Accuracy: 40.86%
...

可以根据需要,调整网络的层数和参数,以便获得更好的分类效果。这个例子只是CIFARNet的一个基本实现,可以作为入门级的参考。CIFARNet网络的原论文提供了更多调优的细节,可以进一步探索。