Reading MNIST with PyTorch

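There are generally two ways to use the MNIST dataset. One of them is already packaged in torchvision; this post explains how to load the MNIST dataset manually.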

Download

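First, download the MNIST dataset from the official website (link). There are four compressed archives in total. Note that the loading code below opens the files with gzip.open, so it expects the downloaded .gz archives as they are rather than extracted copies.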
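The archives are stored in the IDX format, where a short big-endian header precedes the raw bytes. A minimal sketch for inspecting that header is shown here, assuming the files are kept as .gz; the helper name `read_idx_header` is hypothetical and not part of the original post. It also shows why a label file has an 8-byte header and an image file a 16-byte one.

```python
import gzip
import struct

# Magic numbers of the IDX files used by MNIST:
# 2049 (0x00000801) marks a label file, 2051 (0x00000803) marks an image file.
LABEL_MAGIC = 2049
IMAGE_MAGIC = 2051


def read_idx_header(path):
    """Hypothetical helper: return (magic, dims) of a gzip-compressed IDX file."""
    with gzip.open(path, 'rb') as f:
        # The first 4 bytes are the magic number, stored big-endian.
        magic, = struct.unpack('>I', f.read(4))
        # The low byte of the magic number is the number of dimensions.
        ndim = magic & 0xFF
        # Each dimension size is a 4-byte big-endian unsigned integer.
        dims = struct.unpack('>' + 'I' * ndim, f.read(4 * ndim))
    return magic, dims


# Example call (the path is an assumption; adjust it to your download folder):
# magic, dims = read_idx_header('./MNIST/train-labels-idx1-ubyte.gz')
```

A label file has one dimension (4 + 4 = 8 header bytes) and an image file has three (4 + 12 = 16 header bytes), which is where offsets of 8 and 16 come from when the raw bytes are read.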

Reading the data

Copy the code below into readdata.py, then pass in the dataset path to read the data.

import os
import gzip
import numpy as np
from torch.utils.data import Dataset


def load_data(data_folder, data_name, label_name):
    """
    Load MNIST images and labels.
    - data_folder: MNIST folder name
    - data_name: MNIST data file name
    - label_name: MNIST label file name
    """
    # 'rb' opens the file for reading binary data
    with gzip.open(os.path.join(data_folder, label_name), 'rb') as lbpath:
        # offset=8 skips the header of the label file
        y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)

    with gzip.open(os.path.join(data_folder, data_name), 'rb') as imgpath:
        # offset=16 skips the header of the image file; each image is 28x28
        x_train = np.frombuffer(
            imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)
    return (x_train, y_train)


class CustomDataset(Dataset):
    """
    Read and initialize the data.
    """
    def __init__(self, folder, data_name, label_name, transform=None):
        # If the data had been saved with torch.save(), torch.load() could be
        # used instead, and the result would be torch.Tensor objects.
        (train_set, train_labels) = load_data(folder, data_name, label_name)
        self.train_set = train_set
        self.train_labels = train_labels
        self.transform = transform

    def __getitem__(self, index):
        # np.array() copies the sample, so the returned image is writable
        img, target = np.array(self.train_set[index]), int(self.train_labels[index])
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        return len(self.train_set)
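The transform argument of CustomDataset accepts any callable that takes one image and returns the processed image. As an illustration only, here is a minimal NumPy sketch of such a callable; the name `to_float_chw` and the choice of scaling to [0, 1] are assumptions, not something the original post prescribes.

```python
import numpy as np


def to_float_chw(img):
    """Hypothetical transform: uint8 (28, 28) image -> float32 (1, 28, 28) in [0, 1]."""
    # Convert to float32 and scale the pixel values from [0, 255] to [0, 1].
    img = np.asarray(img, dtype=np.float32) / 255.0
    # Add a leading channel axis so the image has shape (channels, height, width).
    return img[np.newaxis, :, :]


# Quick check on a synthetic image (not real MNIST data).
fake_img = np.full((28, 28), 255, dtype=np.uint8)
out = to_float_chw(fake_img)
print(out.shape, out.dtype)
```

It could then be passed as CustomDataset(folder, data_name, label_name, transform=to_float_chw). Wrapping the resulting dataset in torch.utils.data.DataLoader should then batch the returned arrays into tensors with its default collation.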