在加载数据集的时候我们需要对读入的图片进行归一化处理,在pytorch里使用torchvision中的transform来对其进行处理,这里不介绍旋转,裁剪等操作,进介绍归一化操作,会用到下面两个函数
transforms.ToTensor()
transforms.Normalize()
一般处理图片时有两个操作,第一步将其归一化为0-1之间,第二步在使用Normalize进行归一化
ToTensor
这一步很简单,将图片归一化到[0, 1]之间,即将图片像素max(即255)除上255即可,同时将HWC转化为CHW。如下所示,将1除上255即的0.0039
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 data1 = np.array([[[1 , 1 , 1 ], [2 , 2 , 2 ]], [[3 , 3 , 3 ], [4 , 4 , 255 ]]], dtype='uint8' ) data2 = transforms.ToTensor()(data) ''' data1 --> (2,2,3) [[[ 1 1 1] [ 2 2 2]] [[ 3 3 3] [ 4 4 255]]] data2 --> (3, 2, 2) tensor([[[0.0039, 0.0078], # 第一维度R [0.0118, 0.0157]], [[0.0039, 0.0078], # 第二维度G [0.0118, 0.0157]], [[0.0039, 0.0078], # 第三维度B [0.0118, 1.0000]]]) '''
Normalize
需要给Normalize指定参数,Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
其中((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),第一个括号内分别为RGB通道的均值,第二个括号内分别为RGB通道的方差 ,首先我们手动计算一下三个通道的均值和方差,由于上述例子不难,我们直接能得到均值和方差为别为(2.5, 2.5, 65.25),(1.118, 1.118, 109.55)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 data1 = np.array([[[1 , 1 , 1 ], [2 , 2 , 2 ]], [[3 , 3 , 3 ], [4 , 4 , 255 ]]], dtype='uint8' ) transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((2.5 , 2.5 , 65.25 ), (1.118 , 1.118 , 109.55 ))]) data2 = transform(data1) ''' data2 tensor([[[-2.2326, -2.2291], [-2.2256, -2.2221]], [[-2.2326, -2.2291], [-2.2256, -2.2221]], [[-0.5956, -0.5955], [-0.5955, -0.5865]]]) '''
为了验证上述结果我们这里手动计算一下结果,可以发现与上述结果相同
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 data = data/255 data1 = (data[:,:,0 ]-2.5 )/1.118 data2 = (data[:,:,1 ]-2.5 )/1.118 data3 = (data[:,:,2 ]-65.25 )/109.55 ''' 1,2,3,4 --> [[-2.23262829 -2.22912063] [-2.22561296 -2.2221053 ]] 1,2,3,4 --> [[-2.23262829 -2.22912063] [-2.22561296 -2.2221053 ]] 1,2,3,255 --> [[-0.59558264 -0.59554684] [-0.59551105 -0.58649019]] '''
这里就会产生一个问题,为什么最终结果没有归一化到[-1, 1]之间呢,这是因为我们最终通过 归一化的结果为标准正态分布,如果要归一化到[-1, 1]需要将均值设为0.5,方差为0.5才可以,这是因为第一步ToTensor归一化到[0, 1]之后,减去0.5除上0.5就能够到[-1, 1],我们只需要改变Normalize的参数,transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
下面经过试验可以看到确实范围在[-1, 1]了
1 2 3 4 5 6 7 8 9 10 11 12 13 transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5 , 0.5 , 0.5 ), (0.5 , 0.5 , 0.5 ))]) ''' tensor([[[-0.9922, -0.9843], [-0.9765, -0.9686]], [[-0.9922, -0.9843], [-0.9765, -0.9686]], [[-0.9922, -0.9843], [-0.9765, -0.9686]]]) '''