欢迎访问宙启技术站
智能推送

PyTorch中torch.autograd的常用API和操作示例

发布时间:2024-01-03 06:09:16

在PyTorch中,torch.autograd是用于计算梯度的自动微分引擎。它提供了一组API和操作,用于定义和计算张量的梯度。

1. torch.autograd.grad(outputs, inputs, grad_outputs=None, retain_graph=None, create_graph=False, only_inputs=True, allow_unused=False)

此函数用于计算outputs相对于inputs的梯度。输入参数包括:

- outputs:一个张量或张量元组,对于该张量计算梯度。

- inputs:一个张量或张量元组,对于这些张量计算梯度。

- grad_outputs:用于对outputs的梯度进行加权求和的张量。默认情况下,使用梯度为1。

- retain_graph:一个布尔值,指示是否在计算梯度后保留计算图。默认为False。

- create_graph:一个布尔值,指示是否在计算梯度后创建计算图,以便再次计算二阶梯度。默认为False。

- only_inputs:一个布尔值,指示是否仅对inputs计算梯度。默认为True。

- allow_unused:一个布尔值,指示是否允许未使用的输入张量。默认为False。

示例:

import torch

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0)
z = torch.tensor(4.0)

f = x * y + z

grad_x = torch.autograd.grad(f, x)
print(grad_x)

输出为:(tensor(3.),)

2. torch.autograd.backward(tensors, grad_tensors=None, retain_graph=None, create_graph=False, inputs=None)

该函数用于计算tensors的梯度。输入参数包括:

- tensors:一个张量或张量元组,对应于这些张量计算梯度。

- grad_tensors:对于输出张量的梯度张量。默认情况下,使用梯度为1。

- retain_graph:一个布尔值,指示是否在计算梯度后保留计算图。默认为False。

- create_graph:一个布尔值,指示是否在计算梯度后创建计算图,以计算二阶梯度。默认为False。

- inputs:一个张量或张量元组,对于这些张量计算梯度。

示例:

import torch

x = torch.tensor(2.0, requires_grad=True)
y = torch.tensor(3.0)
z = torch.tensor(4.0)

f = x * y + z

torch.autograd.backward(f, retain_graph=True)

grad_x = x.grad
print(grad_x)

输出为:tensor(3.)

3. torch.autograd.gradcheck(func, inputs, eps=1e-6, atol=1e-4, rtol=1e-4, raise_exception=True)

此函数用于检查梯度计算是否正确。输入参数包括:

- func:一个函数或模块,对于该函数或模块计算梯度。

- inputs:一个张量或张量元组,对于这些张量计算梯度。

- eps:一个浮点数,表示数值近似的阈值。默认为1e-6。

- atol:一个浮点数,表示绝对误差的阈值。默认为1e-4。

- rtol:一个浮点数,表示相对误差的阈值。默认为1e-4。

- raise_exception:一个布尔值,指示是否在检查失败时引发异常。默认为True。

示例:

import torch

def func(x):
    return x**3 + 2 * x**2 + x

x = torch.randn(3, requires_grad=True)

torch.autograd.gradcheck(func, x)

输出为:True

除了上述常用API,torch.autograd还提供了计算梯度的其他基本功能,如动态计算图和梯度跟踪等。在实际使用中,可以根据需求选择适当的API和操作来计算和管理梯度。