torch.repeat_interleave

torch.repeat_interleave

官方文档: pytorch

torch.repeat_interleave(input, repeats, dim=None, *, output_size=None)

  • input: 输入tensor
  • repeats: 每个元素的重复次数
  • dim: 重复的维度,默认情况,把输入张量展平为向量,然后将每个元素重复repeats次
  1. 默认的,不给dim参数,将tensor展平
1
2
3
4
5
6
7
8
9
x = torch.tensor([1, 2, 3])
torch.repeat_interleave(x, 2)
y = torch.tensor([[1, 2], [3, 4]])
torch.repeat_interleave(y, 2)

'''
tensor([1, 1, 2, 2, 3, 3])
tensor([1, 1, 2, 2, 3, 3, 4, 4])
'''
  1. 给定dim,会在给定的维度将其展平重复,元素是逐个重复的,如下,对1重复三次,2重复三次,以此类推
1
2
3
4
5
6
7
y = torch.tensor([[1, 2], [3, 4]])
torch.repeat_interleave(y, 3, dim=1)

'''
tensor([[1, 1, 1, 2, 2, 2],
[3, 3, 3, 4, 4, 4]])
'''
  1. 对不同元素指定重复次数
1
2
3
4
5
6
7
8
y = torch.tensor([[1, 2], [3, 4]])
torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0)

'''
tensor([[1, 2],
[3, 4],
[3, 4]])
'''
  1. 此函数也支持tensor.repeat_interleave()的形式
1
2
3
4
5
6
x = torch.tensor([1, 2, 3])
x.repeat_interleave(2)

'''
tensor([1, 1, 2, 2, 3, 3])
'''