欢迎访问宙启技术站
智能推送

PyTorch中交叉熵损失函数的参数解读

发布时间:2023-12-31 12:33:38

在PyTorch中,交叉熵损失函数(CrossEntropyLoss)是一种用于分类问题的损失函数。在使用交叉熵损失函数之前,需要先定义一个模型进行训练。

交叉熵损失函数的主要参数有两个:input和target。其中,input是一个二维张量,代表模型的预测结果;target是一个一维张量,代表真实标签。input和target的形状必须相同。

对于input张量,每行代表一个样本的预测结果,每列代表一个类别。通常,我们可以使用softmax函数将input中的预测结果转化为概率分布。例如,假设有一个样本有三个类别的预测结果为[1.0, 2.0, 3.0],经过softmax函数处理后,可能变为[0.090, 0.244, 0.665]。

对于target张量,每个元素必须是一个类别的索引,该索引对应于input中对应样本的真实标签。例如,如果一个样本属于第二个类别,则对应的target为1。

交叉熵损失函数的计算公式如下:

loss = -∑(target_i * log(input_i))

通过计算预测结果和真实标签之间的交叉熵,我们可以反映模型的训练效果。交叉熵损失函数越小,说明模型的预测结果越接近真实标签,训练效果越好。

下面通过一个简单的示例来说明交叉熵损失函数的使用方法:

import torch
import torch.nn as nn
import torch.optim as optim

# 定义一个三分类模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Linear(3, 3) # 输入特征维度为3,输出类别数为3

    def forward(self, x):
        return self.fc(x)

# 创建模型和数据
model = Net()
input = torch.randn(2, 3) # 生成一个输入样本,形状为(2,3)
target = torch.tensor([1, 2]) # 输入样本对应的真实标签

# 定义交叉熵损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 模型训练
for epoch in range(100):
    optimizer.zero_grad() # 清零梯度缓存
    output = model(input) # 输入样本获取预测结果
    loss = criterion(output, target) # 计算交叉熵损失
    loss.backward() # 反向传播
    optimizer.step() # 更新模型参数

    if (epoch + 1) % 10 == 0:
        print("Epoch {}, Loss: {:.4f}".format(epoch+1, loss.item()))

在上述示例中,我们首先定义了一个三分类模型(Net),该模型的输入特征维度为3,输出类别数为3。然后,我们生成一个输入样本input和对应的真实标签target。

接着,我们定义了交叉熵损失函数(CrossEntropyLoss)和优化器(SGD,随机梯度下降),其中学习率设置为0.01。

在模型训练过程中,我们通过循环迭代的方式,先将梯度清零,然后利用模型获得预测结果output,接着计算交叉熵损失,然后进行反向传播和参数更新。每隔10个epoch,我们打印出当前的损失值。

通过这个例子,你可以看到交叉熵损失函数在分类问题中的使用方式,并通过对输入样本进行训练,不断优化模型的分类能力。