torch.squeeze() removes dimensions of size 1 from a tensor, and torch.unsqueeze() inserts a new dimension of size 1 at a given position. Each has an in-place counterpart: Tensor.squeeze_() and Tensor.unsqueeze_().
1. torch.squeeze()
torch.squeeze(input, dim=None)
- input (Tensor) – the input tensor.
- dim (int, optional) – if given, the input will be squeezed only in this dimension
Removes the dimensions of size 1 from the input tensor.
Returns a tensor with all the dimensions of input of size 1 removed. For example, if input is of shape (A × 1 × B × C × 1 × D), then the out tensor will be of shape (A × B × C × D).
When dim is given, a squeeze operation is done only in the given dimension. If input is of shape (A × 1 × B), squeeze(input, 0) leaves the tensor unchanged, but squeeze(input, 1) will squeeze the tensor to the shape (A × B).
Warning: If the tensor has a batch dimension of size 1, then squeeze(input) will also remove the batch dimension, which can lead to unexpected errors.
x = torch.zeros(2, 1, 2, 1, 2)
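The squeeze behaviour described above can be checked with a short sketch. It reuses the tensor x from the line above; the tensor named batch and its shape (1, 3, 1) are made up here purely to illustrate the batch-dimension warning.

```python
import torch

x = torch.zeros(2, 1, 2, 1, 2)

# No dim given: every size-1 dimension is removed.
all_squeezed = torch.squeeze(x)
print(all_squeezed.size())  # torch.Size([2, 2, 2])

# dim=0 has size 2, so the tensor is left unchanged.
dim0 = torch.squeeze(x, 0)
print(dim0.size())  # torch.Size([2, 1, 2, 1, 2])

# dim=1 has size 1, so only that dimension is removed.
dim1 = torch.squeeze(x, 1)
print(dim1.size())  # torch.Size([2, 2, 1, 2])

# Batch pitfall (hypothetical data): one sample with a trailing size-1 dim.
batch = torch.zeros(1, 3, 1)
lost = torch.squeeze(batch)      # the batch dimension is removed as well
kept = torch.squeeze(batch, -1)  # only the last dimension is removed
print(lost.size())  # torch.Size([3])
print(kept.size())  # torch.Size([1, 3])
```

Passing dim explicitly, as in the last call, is one way to avoid dropping a batch dimension that happens to have size 1.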
2. torch.unsqueeze()
torch.unsqueeze(input, dim)
- input (Tensor) – the input tensor.
- dim (int) – the index at which to insert the singleton dimension
Inserts a dimension of size 1 into the tensor at the specified position.
Returns a new tensor with a dimension of size one inserted at the specified position.
The returned tensor shares the same underlying data with this tensor.
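Because the returned tensor shares its data with the input, a write through one is visible through the other. A minimal sketch of this, with variable names chosen only for illustration:

```python
import torch

x = torch.tensor([1, 2, 3, 4])
y = torch.unsqueeze(x, 0)  # shape (1, 4), shares storage with x

# Writing through y also changes x, since no data was copied.
y[0, 0] = 100
first = x[0].item()
same_storage = x.data_ptr() == y.data_ptr()
print(first)         # 100
print(same_storage)  # True
```

If an independent copy is needed, calling clone() on the result breaks the link between the two tensors.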
A dim value within the range [-input.dim() - 1, input.dim() + 1) can be used. Negative dim will correspond to unsqueeze() applied at dim = dim + input.dim() + 1.
In other words, a negative dim is mapped to a non-negative one: if dim < 0, then dim = dim + input.dim() + 1.
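For example, suppose the input is 2-D with shape (3, 3).
dim must lie in [-3, 3), so the valid values are -3, -2, -1, 0, 1, 2.
dim = -3 or 0 gives output shape (1, 3, 3)
dim = -2 or 1 gives output shape (3, 1, 3)
dim = -1 or 2 gives output shape (3, 3, 1)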
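The mapping for a (3, 3) input can be verified by looping over every valid dim. This is only a sketch; the names shapes and out_of_range_rejected are illustrative.

```python
import torch

x = torch.zeros(3, 3)

# Valid dims for a 2-D input are -3, -2, -1, 0, 1, 2.
shapes = {d: tuple(torch.unsqueeze(x, d).shape) for d in range(-3, 3)}
for d, shape in shapes.items():
    print(d, shape)
# -3 (1, 3, 3)
# -2 (3, 1, 3)
# -1 (3, 3, 1)
# 0 (1, 3, 3)
# 1 (3, 1, 3)
# 2 (3, 3, 1)

# dim=3 lies outside [-3, 3) and is rejected. IndexError is the expected
# exception type; RuntimeError is caught too as a precaution.
try:
    torch.unsqueeze(x, 3)
    out_of_range_rejected = False
except (IndexError, RuntimeError):
    out_of_range_rejected = True
print(out_of_range_rejected)  # True
```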
import torch

x = torch.tensor([1, 2, 3, 4])
print("x:", x.size())
# insert a new dimension at index 0: shape (4,) -> (1, 4)
y1 = torch.unsqueeze(x, 0)
print("torch.unsqueeze(x, 0):", y1.size())
# insert a new dimension at index 1: shape (4,) -> (4, 1)
y2 = torch.unsqueeze(x, 1)
print("torch.unsqueeze(x, 1):", y2.size())
'''
x: torch.Size([4])
torch.unsqueeze(x, 0): torch.Size([1, 4])
torch.unsqueeze(x, 1): torch.Size([4, 1])
'''
print(x)
print(y1)
print(y2)
'''
tensor([1, 2, 3, 4])
tensor([[1, 2, 3, 4]])
tensor([[1],
        [2],
        [3],
        [4]])
'''
''' in-place version '''
# unsqueeze_ changes the shape of x itself instead of returning a new tensor
print(x.size())
x.unsqueeze_(0)
print(x.size())
'''
torch.Size([4])
torch.Size([1, 4])
'''
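The other in-place operation, Tensor.squeeze_(), behaves analogously. A small sketch, using a made-up tensor z of shape (1, 4, 1):

```python
import torch

z = torch.zeros(1, 4, 1)  # hypothetical tensor for illustration

# With dim given, only that size-1 dimension is removed, in place.
result = z.squeeze_(0)
shape_after_dim = tuple(z.shape)
print(z.size())     # torch.Size([4, 1])
print(result is z)  # True: the in-place call returns the tensor itself

# Without dim, all remaining size-1 dimensions are removed, in place.
z.squeeze_()
shape_after_all = tuple(z.shape)
print(z.size())     # torch.Size([4])
```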