PyTorch中关于gather()函数的详细解析
发布时间:2023-12-26 04:26:51
在PyTorch中,gather()函数用于从输入张量中收集指定的元素。它有两个参数, 个参数是输入张量,第二个参数是指定要收集元素的索引。
gather()函数可以用于多种场景,例如:
1. 对一个张量进行索引取值,选取指定位置的元素。
2. 根据指定的索引获取对应的元素,可用于实现类似于取最大值的操作。
下面我们将分别介绍这两种应用场景,并给出对应的使用示例。
1. 索引取值
当要从一个张量中获取指定位置的元素时,可以使用gather()函数。在这种情况下,输入张量必须为一维张量。
import torch # 创建输入张量 input_tensor = torch.tensor([1, 2, 3, 4, 5]) # 创建索引张量 index_tensor = torch.tensor([0, 2, 4]) # 使用gather()函数从输入张量中获取指定位置的元素 output_tensor = torch.gather(input_tensor, 0, index_tensor) print(output_tensor)
输出结果为:
tensor([1, 3, 5])
在这个例子中,输入张量input_tensor为一维张量[1, 2, 3, 4, 5],索引张量index_tensor为一维张量[0, 2, 4]。使用gather()函数从输入张量中根据索引张量获取指定位置的元素,输出结果为[1, 3, 5]。
2. 根据索引获取元素
当要根据指定的索引获取对应的元素时,可以使用gather()函数。在这种情况下,输入张量可以为任意维度。
import torch # 创建输入张量 input_tensor = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) # 创建索引张量 index_tensor = torch.tensor([[0, 2], [1, 0], [2, 1]]) # 使用gather()函数根据索引获取对应的元素 output_tensor = torch.gather(input_tensor, 1, index_tensor) print(output_tensor)
输出结果为:
tensor([[1, 3],
[5, 4],
[9, 8]])
在这个例子中,输入张量input_tensor为二维张量[[1, 2, 3], [4, 5, 6], [7, 8, 9]],索引张量index_tensor为二维张量[[0, 2], [1, 0], [2, 1]]。使用gather()函数根据索引张量从输入张量中获取对应的元素,输出结果为[[1, 3], [5, 4], [9, 8]]。
总结来说,gather()函数可以用于从输入张量中收集指定的元素。它可以用于索引取值和根据索引获取元素的场景。在使用时,需要注意输入张量的维度和索引张量的形状需要匹配。
