torch.repeat_interleave
官方文档: pytorch
torch.repeat_interleave(input, repeats, dim=None, *, output_size=None)
- input: 输入tensor
- repeats: 每个元素的重复次数
- dim: 重复的维度,默认情况,把输入张量展平为向量,然后将每个元素重复repeats次
- 默认的,不给dim参数,将tensor展平
1 2 3 4 5 6 7 8 9
| x = torch.tensor([1, 2, 3]) torch.repeat_interleave(x, 2) y = torch.tensor([[1, 2], [3, 4]]) torch.repeat_interleave(y, 2)
''' tensor([1, 1, 2, 2, 3, 3]) tensor([1, 1, 2, 2, 3, 3, 4, 4]) '''
|
- 给定dim,会在给定的维度将其展平重复,元素是逐个重复的,如下,对1重复三次,2重复三次,以此类推
1 2 3 4 5 6 7
| y = torch.tensor([[1, 2], [3, 4]]) torch.repeat_interleave(y, 3, dim=1)
''' tensor([[1, 1, 1, 2, 2, 2], [3, 3, 3, 4, 4, 4]]) '''
|
- 对不同元素指定重复次数
1 2 3 4 5 6 7 8
| y = torch.tensor([[1, 2], [3, 4]]) torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0)
''' tensor([[1, 2], [3, 4], [3, 4]]) '''
|
- 此函数也支持tensor.repeat_interleave()的形式
1 2 3 4 5 6
| x = torch.tensor([1, 2, 3]) x.repeat_interleave(2)
''' tensor([1, 1, 2, 2, 3, 3]) '''
|