dataset和dataloader

Dataset && DataLoader

前言:为了易读以及适应python的模块化编程,PyTorch提供了两个加载数据的原型,分别为:torch.utils.data.Dataset以及torch.utils.data.DataLoader,其中Dataset存储了数据集的样本已经相应的标签,DataLoader将其进一步进行包装成为一个迭代器使得我们可以更容易的从中获取训练样本

PyTorch库中已经封装大量的数据集(例如FashionMNIST),可以从下面两个链接中找到使用方法,这里仅对你自己定制的数据集进行解释说明。链接一链接二

Dataset

自己定制的Dataset继承于torch.utils.data.Dataset,需要实现三个函数:__init____len____getitem__

  • init:此函数在实例化Dataset类时被调用,用来初始化:包含图片的文件夹,标注文件以及transforms
  • len:返回样本的总数量
  • getitem:从给定索引index中索引一个样本,并转化为张量,通过transform对其进行处理, 最后返回张量以及相应的标签等

具体的流程是,首先我们在一个文件夹A下有训练集,然后我们将A中的图片在init初始化,之后len就是训练的总数,getitem是在训练时对其进行索引,在索引是会进行一系列的操作,例如旋转等

给出Dataset的代码示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def is_image_file(filename):
return any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif'])

class CustomDataset(Dataset):
def __init__(self, image_dir):
super(CustomDataset, self).__init__()
self.image_dir = image_dir

inp = sorted(os.listdir(os.path.join(self.image_dir)))
self.inp_filenames = [os.path.join(self.image_dir, x) for x in inp if is_image_file(x)]

def __len__(self):
return len(self.inp_filenames)

def __getitem__(self, index):
input_path = self.inp_filenames[index]
input = Image.open(input_path)
inp_img = TF.to_tensor(input)
return inp_img

DataLoader

只有Dataset是不够的,因为在训练时我们常常需要将数据集划分为Batch进行训练,并且为了防止过拟合,我们希望数据集在每个epoch都是被打乱的

使用方法也比较简单,只需要将之前得到的Dataset包装进去即可

1
2
3
4
train_dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
for _, img in enumerate(DataLoader):
input = img
...