torch.squeeze

torch.squeeze() removes the dimensions of size 1 from a tensor, and torch.unsqueeze() inserts a new dimension of size 1 at a given position. Each has an in-place counterpart: Tensor.squeeze_() and Tensor.unsqueeze_().
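The two functions undo each other. If you unsqueeze at position d and then squeeze that same position, you get back the original shape. The sketch below uses an arbitrary example shape (2, 3) chosen only for illustration. It also shows the in-place variants, which change the shape of the tensor they are called on.

```python
import torch

x = torch.arange(6).reshape(2, 3)   # example data, shape (2, 3)

# Insert a size-1 dimension at position 1, then remove it again.
u = torch.unsqueeze(x, 1)           # shape (2, 1, 3)
s = torch.squeeze(u, 1)             # back to shape (2, 3)

print(u.shape)                      # torch.Size([2, 1, 3])
print(s.shape)                      # torch.Size([2, 3])
print(torch.equal(s, x))            # True

# In-place variants modify t itself rather than returning a new tensor.
t = x.clone()
t.unsqueeze_(0)                     # t now has shape (1, 2, 3)
t.squeeze_(0)                       # t is back to shape (2, 3)
print(t.shape)                      # torch.Size([2, 3])
```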

1. torch.squeeze()

torch.squeeze(input, dim=None)

  • input (Tensor) – the input tensor.
  • dim (int, optional) – if given, the input will be squeezed only in this dimension

Removes the dimensions of size 1 from the input tensor.

Returns a tensor with all the dimensions of input of size 1 removed.

For example, if input is of shape (A × 1 × B × C × 1 × D), then the out tensor will be of shape (A × B × C × D).

When dim is given, a squeeze operation is done only in the given dimension. If input is of shape (A × 1 × B), squeeze(input, 0) leaves the tensor unchanged, but squeeze(input, 1) will squeeze the tensor to the shape (A × B).
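Here is a concrete check of the rule above. It uses a 6-D tensor of shape (3, 1, 4, 5, 1, 6) and a 3-D tensor of shape (3, 1, 4); the sizes are arbitrary. Squeezing without dim removes every size-1 axis. Squeezing a dim whose size is not 1 is a no-op.

```python
import torch

# Shape (A, 1, B, C, 1, D) with arbitrary sizes A=3, B=4, C=5, D=6
full = torch.zeros(3, 1, 4, 5, 1, 6)
all_removed = torch.squeeze(full)
print(all_removed.shape)          # torch.Size([3, 4, 5, 6])

# Shape (A, 1, B) with A=3, B=4
small = torch.zeros(3, 1, 4)
dim0 = torch.squeeze(small, 0)    # dim 0 has size 3, so the shape is unchanged
dim1 = torch.squeeze(small, 1)    # dim 1 has size 1, so it is removed
print(dim0.shape)                 # torch.Size([3, 1, 4])
print(dim1.shape)                 # torch.Size([3, 4])
```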

Warning: If the tensor has a batch dimension of size 1, then squeeze(input) will also remove the batch dimension, which can lead to unexpected errors.
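This warning is easy to reproduce. The sketch below uses a hypothetical output of shape (batch, 1, features) with batch size 1. Calling squeeze() with no dim drops the batch axis too, while squeeze(1) removes only the axis you name. The shapes are illustrative assumptions, not taken from any particular model.

```python
import torch

# Hypothetical output of shape (batch=1, channels=1, features=3)
out = torch.randn(1, 1, 3)

risky = out.squeeze()     # removes BOTH size-1 dims -> shape (3,)
safe = out.squeeze(1)     # removes only dim 1 -> shape (1, 3), batch dim kept
print(risky.shape)        # torch.Size([3])
print(safe.shape)         # torch.Size([1, 3])

# With batch size 2 the two calls agree, which hides the problem
out2 = torch.randn(2, 1, 3)
print(out2.squeeze().shape, out2.squeeze(1).shape)  # torch.Size([2, 3]) torch.Size([2, 3])
```

Because both calls return the same shape whenever the batch size is larger than 1, this kind of bug tends to surface only when a batch happens to contain a single sample. Passing an explicit dim avoids it.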

```python
import torch

x = torch.zeros(2, 1, 2, 1, 2)
print(x.size())
'''
torch.Size([2, 1, 2, 1, 2])
'''


y = torch.squeeze(x)      # removes every size-1 dimension
print("torch.squeeze(x):", y.size())
y = torch.squeeze(x, 0)   # dim 0 has size 2, so nothing changes
print("torch.squeeze(x, 0):", y.size())
y = torch.squeeze(x, 1)   # dim 1 has size 1, so it is removed
print("torch.squeeze(x, 1):", y.size())
'''
torch.squeeze(x): torch.Size([2, 2, 2])
torch.squeeze(x, 0): torch.Size([2, 1, 2, 1, 2])
torch.squeeze(x, 1): torch.Size([2, 2, 1, 2])
'''


# in-place version: modifies x itself instead of returning a new tensor
print(x.size())
x.squeeze_()
print(x.size())
'''
torch.Size([2, 1, 2, 1, 2])
torch.Size([2, 2, 2])
'''
```

2. torch.unsqueeze()

torch.unsqueeze(input, dim)

  • input (Tensor) – the input tensor.
  • dim (int) – the index at which to insert the singleton dimension

Inserts a new dimension of size 1 into the tensor at the given position.

Returns a new tensor with a dimension of size one inserted at the specified position.

The returned tensor shares the same underlying data with this tensor.
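Because the result shares its data with the input, writing through the unsqueezed tensor also changes the original. The following is a minimal check; the values are arbitrary.

```python
import torch

x = torch.tensor([1, 2, 3, 4])
y = torch.unsqueeze(x, 0)              # shape (1, 4), shares x's data

y[0, 0] = 100                          # write through y
print(x.tolist())                      # [100, 2, 3, 4]: x sees the change
print(x.data_ptr() == y.data_ptr())    # True: same underlying memory
```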

A dim value within the range [-input.dim() - 1, input.dim() + 1) can be used. Negative dim will correspond to unsqueeze() applied at dim = dim + input.dim() + 1.
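The mapping rule can be checked directly. For each negative dim, unsqueeze(x, dim) should have the same shape as unsqueeze(x, dim + x.dim() + 1). The sketch below does this on a 1-D tensor, whose valid range is [-2, 2). It also shows that a dim outside that range is rejected. In the PyTorch versions I know, the rejection is an IndexError.

```python
import torch

x = torch.tensor([1, 2, 3, 4])   # 1-D, so x.dim() == 1
n = x.dim()

# Valid dims are [-n - 1, n + 1) = [-2, 2); a negative d maps to d + n + 1.
for d in range(-n - 1, 0):       # d = -2, -1
    mapped = d + n + 1           # 0, 1
    same = torch.unsqueeze(x, d).shape == torch.unsqueeze(x, mapped).shape
    print(d, "->", mapped, same) # prints "-2 -> 0 True", then "-1 -> 1 True"

# A dim outside the valid range is rejected.
rejected = False
try:
    torch.unsqueeze(x, n + 1)    # 2 is outside [-2, 2)
except IndexError as err:
    rejected = True
    print("IndexError:", err)
```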

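In other words, dim must lie in [-input.dim() - 1, input.dim() + 1), and a negative dim is mapped to the corresponding non-negative one: if dim < 0: dim = dim + input.dim() + 1.

For example, suppose the input is 2-D with shape (3, 3).

Then dim ranges over [-3, 3), i.e. it can be -3, -2, -1, 0, 1, or 2:

  • dim = -3 or 0: output shape (1, 3, 3)
  • dim = -2 or 1: output shape (3, 1, 3)
  • dim = -1 or 2: output shape (3, 3, 1)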
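The shapes for a (3, 3) input can be reproduced by looping over every valid dim:

```python
import torch

x = torch.zeros(3, 3)

# Valid dims for a 2-D input: -3, -2, -1, 0, 1, 2
shapes = {d: tuple(torch.unsqueeze(x, d).shape) for d in range(-3, 3)}
for d, shape in shapes.items():
    print(f"dim = {d:>2}: {shape}")
```

Each negative dim produces the same shape as the non-negative dim it maps to: -3 matches 0, -2 matches 1, and -1 matches 2.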

```python
import torch

x = torch.tensor([1, 2, 3, 4])
print("x:                    ", x.size())
y1 = torch.unsqueeze(x, 0)   # insert a new dim at position 0
print("torch.unsqueeze(x, 0):", y1.size())
y2 = torch.unsqueeze(x, 1)   # insert a new dim at position 1
print("torch.unsqueeze(x, 1):", y2.size())
'''
x:                     torch.Size([4])
torch.unsqueeze(x, 0): torch.Size([1, 4])
torch.unsqueeze(x, 1): torch.Size([4, 1])
'''


print(x)
print(y1)
print(y2)
'''
tensor([1, 2, 3, 4])
tensor([[1, 2, 3, 4]])
tensor([[1],
        [2],
        [3],
        [4]])
'''


# in-place version: modifies x itself instead of returning a new tensor
print(x.size())
x.unsqueeze_(0)
print(x.size())
'''
torch.Size([4])
torch.Size([1, 4])
'''
```