gather

Function signature

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

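Purpose

Gathers values from input along the axis specified by dim, using index to choose the position along that axis.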

Output

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
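The indexing rule above can be spelled out with explicit loops. What follows is a minimal pure-Python sketch for the 2-D case only. gather_2d is a hypothetical helper written for illustration, not part of PyTorch, and it works on nested lists rather than tensors.

```python
def gather_2d(inp, dim, index):
    """Reference version of the gather rule for 2-D nested lists (illustrative only)."""
    if dim not in (0, 1):
        raise ValueError("this sketch only handles dim 0 or 1")
    out = []
    for i, index_row in enumerate(index):
        out_row = []
        for j, idx in enumerate(index_row):
            if dim == 0:
                # out[i][j] = inp[index[i][j]][j]
                out_row.append(inp[idx][j])
            else:
                # out[i][j] = inp[i][index[i][j]]
                out_row.append(inp[i][idx])
        out.append(out_row)
    return out


t = [[1, 2], [3, 4], [6, 8]]
by_rows = gather_2d(t, 0, [[1, 0], [0, 2], [2, 1]])  # index selects the row
by_cols = gather_2d(t, 1, [[1, 0], [0, 1], [1, 1]])  # index selects the column
print(by_rows)
print(by_cols)
```

The output always takes its shape from index, because the loops only ever visit positions of index.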

Examples

>>> import torch
>>> t = torch.tensor([[1, 2], [3, 4]])
>>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
tensor([[1, 1],
        [4, 3]])
>>> t = torch.tensor([[1, 2], [3, 4], [6, 8]])
>>> t
tensor([[1, 2],
        [3, 4],
        [6, 8]])
>>> torch.gather(t, 1, torch.tensor([[1, 0], [0, 2]]))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: index 2 is out of bounds for dimension 1 with size 2
>>> torch.gather(t, 0, torch.tensor([[1, 0], [0, 2], [2, 1]]))
tensor([[3, 2],
        [1, 8],
        [6, 4]])
>>> torch.gather(t, 1, torch.tensor([[1, 0], [0, 1], [1, 1]]))
tensor([[2, 1],
        [3, 4],
        [8, 8]])
>>> torch.gather(t, 1, torch.tensor([[1, 0, 0], [0, 0, 1], [0, 1, 1]]))
tensor([[2, 1, 1],
        [3, 3, 4],
        [6, 8, 8]])

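As the figure below shows, when dim=0 each column of output is filled by walking down the matching column of index and using each entry as a row index into the same column of input. When dim=1, each row of output is filled by walking the matching row of index from left to right and using each entry as a column index into the same row of input.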

[Figure: gather with dim=0 and dim=1]

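Summary

  • output has the same shape as index.
  • Every value in index must lie in [0, input.size(dim) - 1], that is, within the length of input along dim.
  • index must have the same number of dimensions as input.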
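A common use that touches all three points is picking one value per row, for example the score of each sample's labelled class. The sketch below assumes PyTorch is installed; scores and labels are made-up names and data chosen for illustration.

```python
import torch

# Hypothetical data: 3 samples, 4 class scores each.
scores = torch.tensor([[0.1, 0.2, 0.3, 0.4],
                       [1.0, 2.0, 3.0, 4.0],
                       [10.0, 20.0, 30.0, 40.0]])
labels = torch.tensor([3, 0, 2])  # one class id per sample, each in [0, 3]

# index must have as many dimensions as input, so reshape (3,) -> (3, 1).
index = labels.unsqueeze(1)

picked = torch.gather(scores, 1, index)  # shape follows index: (3, 1)
per_sample = picked.squeeze(1)           # drop the extra axis: (3,)

print(picked.shape)
print(per_sample)
```

Passing labels directly, without unsqueeze, would violate the dimension-count rule, since labels is 1-D while scores is 2-D.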

scatter_

Function signature

Tensor.scatter_(dim, index, src, reduce=None) → Tensor

Purpose

Writes the values of src into the tensor itself (self), at the positions along dim given by index.

Output

self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
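The assignment rule above can likewise be written as explicit loops. This is a minimal pure-Python sketch for the 2-D case only. scatter_2d is a hypothetical helper for illustration, not a PyTorch function, and it mutates a nested list in place the way scatter_ mutates self.

```python
def scatter_2d(target, dim, index, src):
    """In-place reference version of the scatter_ rule for 2-D nested lists (illustrative only)."""
    if dim not in (0, 1):
        raise ValueError("this sketch only handles dim 0 or 1")
    for i, index_row in enumerate(index):
        for j, idx in enumerate(index_row):
            if dim == 0:
                # target[index[i][j]][j] = src[i][j]
                target[idx][j] = src[i][j]
            else:
                # target[i][index[i][j]] = src[i][j]
                target[i][idx] = src[i][j]
    return target


src = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]
along_dim0 = scatter_2d([[0] * 5 for _ in range(3)], 0, [[0, 1, 2, 0]], src)
along_dim1 = scatter_2d([[0] * 5 for _ in range(3)], 1, [[0, 1, 2], [0, 1, 4]], src)
print(along_dim0)
print(along_dim1)
```

Only the positions covered by index are visited, so entries of src outside the shape of index are never written.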

Examples

>>> src = torch.arange(1, 11).reshape((2, 5))
>>> src
tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10]])
>>> index = torch.tensor([[0, 1, 2, 0]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
tensor([[1, 0, 0, 4, 0],
        [0, 2, 0, 0, 0],
        [0, 0, 3, 0, 0]])
>>> index = torch.tensor([[0, 1, 2], [0, 1, 4]])
>>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
tensor([[1, 2, 3, 0, 0],
        [6, 7, 0, 0, 8],
        [0, 0, 0, 0, 0]])

>>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
...            1.23, reduce='multiply')
tensor([[2.0000, 2.0000, 2.4600, 2.0000],
        [2.0000, 2.0000, 2.0000, 2.4600]])
>>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
...            1.23, reduce='add')
tensor([[2.0000, 2.0000, 3.2300, 2.0000],
        [2.0000, 2.0000, 2.0000, 3.2300]])

[Figure: scatter_]
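One frequent application of scatter_ with a scalar value, as in the last two examples, is building a one-hot encoding. The sketch below assumes PyTorch is installed and omits reduce so the value simply overwrites. The names labels and num_classes are invented for illustration.

```python
import torch

labels = torch.tensor([2, 0, 1, 2])  # hypothetical class ids, each in [0, 2]
num_classes = 3

one_hot = torch.zeros(labels.size(0), num_classes, dtype=torch.long)
# For each row i, write 1 at column labels[i]; index needs shape (4, 1).
one_hot.scatter_(1, labels.unsqueeze(1), 1)

# gather with the same index reads those same entries back.
read_back = torch.gather(one_hot, 1, labels.unsqueeze(1))

print(one_hot)
print(read_back)
```

Here gather undoes the lookup that scatter_ performed: both use the same dim and the same index, one to write and the other to read.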

Summary