如何通过Python中的onnx.save()函数保存PyTorch模型为ONNX文件
发布时间:2024-01-11 06:30:30
要将PyTorch模型保存为ONNX文件,可以使用PyTorch提供的onnx库中的save()函数。该函数接受两个参数:要保存的模型和要保存到的文件路径。下面是一个使用示例:
首先,需要安装PyTorch和onnx库。可以通过pip安装它们:
pip install torch pip install onnx
然后,导入必要的库:
import torch import torchvision import torch.onnx as onnx
接下来,将要保存的模型加载到内存中。这里以ResNet18模型为例:
model = torchvision.models.resnet18(pretrained=True)
在保存模型之前,需要将其设置为评估模式:
model.eval()
接下来,创建一个输入张量。这个张量用于模型的前向传播,并将其保存为一个变量:
input_image = torch.zeros((1,3,224,224))
然后,将模型和输入张量传递给onnx.export()函数,并指定要保存的文件路径:
onnx.export(model, input_image, 'resnet18.onnx', export_params=True)
在此示例中,'resnet18.onnx'是要保存的ONNX文件的路径。export_params参数设置为True,以保存模型的参数。
最后,运行上述代码将模型保存为ONNX文件。可以使用torch.onnx.export()函数的onnx_producer_name和onnx_producer_version参数来设置ONNX文件的生产者名称和版本。
完整的示例代码如下:
import torch import torchvision import torch.onnx as onnx # 导入模型 model = torchvision.models.resnet18(pretrained=True) model.eval() # 创建输入张量 input_image = torch.zeros((1,3,224,224)) # 保存模型为ONNX文件 onnx.export(model, input_image, 'resnet18.onnx', export_params=True, opset_version=10)
这样,就可以将PyTorch模型保存为ONNX文件。ONNX文件可以在其他框架中加载和使用。
