欢迎访问宙启技术站
智能推送

PyTorch中关于gather()函数的详细解析

发布时间:2023-12-26 04:26:51

在PyTorch中,gather()函数用于从输入张量中收集指定的元素。它有两个参数, 个参数是输入张量,第二个参数是指定要收集元素的索引。

gather()函数可以用于多种场景,例如:

1. 对一个张量进行索引取值,选取指定位置的元素。

2. 根据指定的索引获取对应的元素,可用于实现类似于取最大值的操作。

下面我们将分别介绍这两种应用场景,并给出对应的使用示例。

1. 索引取值

当要从一个张量中获取指定位置的元素时,可以使用gather()函数。在这种情况下,输入张量必须为一维张量。

import torch

# 创建输入张量
input_tensor = torch.tensor([1, 2, 3, 4, 5])

# 创建索引张量
index_tensor = torch.tensor([0, 2, 4])

# 使用gather()函数从输入张量中获取指定位置的元素
output_tensor = torch.gather(input_tensor, 0, index_tensor)

print(output_tensor)

输出结果为:

tensor([1, 3, 5])

在这个例子中,输入张量input_tensor为一维张量[1, 2, 3, 4, 5],索引张量index_tensor为一维张量[0, 2, 4]。使用gather()函数从输入张量中根据索引张量获取指定位置的元素,输出结果为[1, 3, 5]

2. 根据索引获取元素

当要根据指定的索引获取对应的元素时,可以使用gather()函数。在这种情况下,输入张量可以为任意维度。

import torch

# 创建输入张量
input_tensor = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# 创建索引张量
index_tensor = torch.tensor([[0, 2], [1, 0], [2, 1]])

# 使用gather()函数根据索引获取对应的元素
output_tensor = torch.gather(input_tensor, 1, index_tensor)

print(output_tensor)

输出结果为:

tensor([[1, 3],
        [5, 4],
        [9, 8]])

在这个例子中,输入张量input_tensor为二维张量[[1, 2, 3], [4, 5, 6], [7, 8, 9]],索引张量index_tensor为二维张量[[0, 2], [1, 0], [2, 1]]。使用gather()函数根据索引张量从输入张量中获取对应的元素,输出结果为[[1, 3], [5, 4], [9, 8]]

总结来说,gather()函数可以用于从输入张量中收集指定的元素。它可以用于索引取值和根据索引获取元素的场景。在使用时,需要注意输入张量的维度和索引张量的形状需要匹配。