欢迎访问宙启技术站
智能推送

Python实现的LeNet模型参数生成

发布时间:2023-12-11 06:23:33

LeNet是一种经典的卷积神经网络架构,由Yann LeCun在1998年提出,用于手写数字识别任务。它有五层:两个卷积层(C1和C3)、两个池化层(S2和S4)以及一个全连接层(F5)。在这里,我们将使用Python实现LeNet模型,并生成模型的参数。

首先,我们需要导入必要的库:

import torch
import torch.nn as nn
import torch.nn.functional as F

然后,我们定义LeNet模型的类,继承自nn.Module

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        
        # 定义卷积层
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        
        # 定义全连接层
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        
    def forward(self, x):
        # 前向传播
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        
        return x

__init__方法中,我们定义了LeNet模型的结构。nn.Conv2d是卷积层,nn.Linear是全连接层。在forward方法中,我们定义了前向传播的流程。

接着,我们可以创建LeNet模型的实例:

model = LeNet()

现在,我们可以使用state_dict方法来获取模型的参数:

model_parameters = model.state_dict()

state_dict方法返回一个字典,其中包含了模型的所有参数和缓冲区。

接下来,我们可以将这些参数保存到文件中,以备以后使用:

torch.save(model_parameters, 'lenet.pth')

现在,模型的参数已经保存到了lenet.pth文件中。

如果我们想要加载模型参数,可以使用load_state_dict方法:

model = LeNet()
model.load_state_dict(torch.load('lenet.pth'))

这样,模型的参数就会被加载到对应的层中。

下面是一个完整的示例,展示了如何构建LeNet模型、生成参数、保存参数,并使用加载的参数进行预测:

import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        
        return x

def main():
    model = LeNet()
    
    model_parameters = model.state_dict()
    torch.save(model_parameters, 'lenet.pth')
    
    new_model = LeNet()
    new_model.load_state_dict(torch.load('lenet.pth'))
    
    # 使用加载的参数进行预测
    input_tensor = torch.randn(1, 1, 32, 32)
    output = new_model(input_tensor)
    print(output)
    
if __name__ == '__main__':
    main()

通过上述步骤,我们实现了LeNet模型,生成了模型的参数,并通过加载参数进行了预测。这个例子可以作为学习使用PyTorch构建和训练模型的参考。