欢迎访问宙启技术站
智能推送

使用torch.nn.init.constant_()在PyTorch中初始化常数值

发布时间:2023-12-24 16:04:55

PyTorch提供了torch.nn.init.constant_()函数,用于在模型中初始化权重和偏置等参数为常数值。

该函数的语法如下:

torch.nn.init.constant_(tensor, val)

参数解释:

- tensor:需要初始化的张量。

- val:常数值。

以下是torch.nn.init.constant_()的使用示例:

import torch
import torch.nn as nn
import torch.nn.init as init

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc = nn.Linear(10, 5)

    def forward(self, x):
        x = self.fc(x)
        return x

model = MyModel()

# 打印初始化前的模型权重
print("初始化前的权重:")
print(model.fc.weight)

# 使用torch.nn.init.constant_()函数将权重初始化为常数值
val = 0.5
init.constant_(model.fc.weight, val)

# 打印初始化后的模型权重
print("
初始化后的权重:")
print(model.fc.weight)

输出结果:

初始化前的权重:
Parameter containing:
tensor([[-0.1604, -0.1743,  0.2572,  0.1146, -0.0879, -0.1713,  0.2297,  0.1105,
          0.2574,  0.2721],
        [ 0.3053,  0.1250,  0.1569, -0.2717, -0.2306,  0.0089, -0.2351,  0.1191,
         -0.1524,  0.0807],
        [ 0.0225, -0.2666, -0.2610, -0.0650,  0.2814, -0.2992, -0.1418, -0.0364,
          0.2198,  0.1500],
        [ 0.3149,  0.2732,  0.0904, -0.1273, -0.2985,  0.0585,  0.1239, -0.0553,
          0.1451, -0.0976],
        [-0.3039, -0.1594,  0.2229, -0.2436, -0.0783,  0.0415,  0.1138,  0.2192,
         -0.1616,  0.2952]], requires_grad=True)

初始化后的权重:
Parameter containing:
tensor([[0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
         0.5000],
        [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
         0.5000],
        [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
         0.5000],
        [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
         0.5000],
        [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
         0.5000]], requires_grad=True)

在上述示例中,我们创建了一个简单的包含一个全连接层的模型MyModel。我们使用torch.nn.init.constant_()函数将全连接层的权重初始化为常数值0.5。在初始化之前和之后,我们打印了模型的权重。可以看到,初始化前权重是随机的,而初始化后所有的权重都变为了0.5。