PyTorch中torch.autograd的常用API和操作示例
在PyTorch中,torch.autograd是用于计算梯度的自动微分引擎。它提供了一组API和操作,用于定义和计算张量的梯度。
1. torch.autograd.grad(outputs, inputs, grad_outputs=None, retain_graph=None, create_graph=False, only_inputs=True, allow_unused=False)
此函数用于计算outputs相对于inputs的梯度。输入参数包括:
- outputs:一个张量或张量元组,对于该张量计算梯度。
- inputs:一个张量或张量元组,对于这些张量计算梯度。
- grad_outputs:用于对outputs的梯度进行加权求和的张量。默认情况下,使用梯度为1。
- retain_graph:一个布尔值,指示是否在计算梯度后保留计算图。默认为False。
- create_graph:一个布尔值,指示是否在计算梯度后创建计算图,以便再次计算二阶梯度。默认为False。
- only_inputs:一个布尔值,指示是否仅对inputs计算梯度。默认为True。
- allow_unused:一个布尔值,指示是否允许未使用的输入张量。默认为False。
示例:
import torch x = torch.tensor(2.0, requires_grad=True) y = torch.tensor(3.0) z = torch.tensor(4.0) f = x * y + z grad_x = torch.autograd.grad(f, x) print(grad_x)
输出为:(tensor(3.),)
2. torch.autograd.backward(tensors, grad_tensors=None, retain_graph=None, create_graph=False, inputs=None)
该函数用于计算tensors的梯度。输入参数包括:
- tensors:一个张量或张量元组,对应于这些张量计算梯度。
- grad_tensors:对于输出张量的梯度张量。默认情况下,使用梯度为1。
- retain_graph:一个布尔值,指示是否在计算梯度后保留计算图。默认为False。
- create_graph:一个布尔值,指示是否在计算梯度后创建计算图,以计算二阶梯度。默认为False。
- inputs:一个张量或张量元组,对于这些张量计算梯度。
示例:
import torch x = torch.tensor(2.0, requires_grad=True) y = torch.tensor(3.0) z = torch.tensor(4.0) f = x * y + z torch.autograd.backward(f, retain_graph=True) grad_x = x.grad print(grad_x)
输出为:tensor(3.)
3. torch.autograd.gradcheck(func, inputs, eps=1e-6, atol=1e-4, rtol=1e-4, raise_exception=True)
此函数用于检查梯度计算是否正确。输入参数包括:
- func:一个函数或模块,对于该函数或模块计算梯度。
- inputs:一个张量或张量元组,对于这些张量计算梯度。
- eps:一个浮点数,表示数值近似的阈值。默认为1e-6。
- atol:一个浮点数,表示绝对误差的阈值。默认为1e-4。
- rtol:一个浮点数,表示相对误差的阈值。默认为1e-4。
- raise_exception:一个布尔值,指示是否在检查失败时引发异常。默认为True。
示例:
import torch
def func(x):
return x**3 + 2 * x**2 + x
x = torch.randn(3, requires_grad=True)
torch.autograd.gradcheck(func, x)
输出为:True
除了上述常用API,torch.autograd还提供了计算梯度的其他基本功能,如动态计算图和梯度跟踪等。在实际使用中,可以根据需求选择适当的API和操作来计算和管理梯度。
