import os import gzip import numpy as np from torch.utils.data import Dataset
def load_data(data_folder, data_name, label_name):
    """Load an MNIST-style image/label pair from gzip-compressed IDX files.

    Args:
        data_folder: MNIST folder name (directory holding both files).
        data_name: MNIST image file name (gzip-compressed IDX3, uint8 pixels).
        label_name: MNIST label file name (gzip-compressed IDX1, uint8 labels).

    Returns:
        ``(x_train, y_train)`` where ``x_train`` is a uint8 array of shape
        ``(num_items, rows, cols)`` and ``y_train`` is a uint8 array of shape
        ``(num_items,)``. ``rows``/``cols`` come from the image file header
        (28x28 for MNIST). Both arrays are read-only views over the
        decompressed file bytes (``np.frombuffer``); copy them if you need
        to write.

    Raises:
        ValueError: if a file's IDX magic number is not the expected one,
            or the image count in the header differs from the label count.
    """
    # 'rb' reads the decompressed stream as raw binary data.
    with gzip.open(os.path.join(data_folder, label_name), 'rb') as lbpath:
        label_bytes = lbpath.read()
    # IDX1 header: big-endian uint32 magic (0x00000801) and item count.
    label_magic = int(np.frombuffer(label_bytes, dtype='>u4', count=1)[0])
    if label_magic != 2049:
        raise ValueError(
            f"{label_name!r}: unexpected IDX label magic number {label_magic}"
        )
    y_train = np.frombuffer(label_bytes, np.uint8, offset=8)

    with gzip.open(os.path.join(data_folder, data_name), 'rb') as imgpath:
        image_bytes = imgpath.read()
    # IDX3 header: big-endian uint32 magic (0x00000803), item count, rows, cols.
    image_magic, num_images, rows, cols = (
        int(v) for v in np.frombuffer(image_bytes, dtype='>u4', count=4)
    )
    if image_magic != 2051:
        raise ValueError(
            f"{data_name!r}: unexpected IDX image magic number {image_magic}"
        )
    if num_images != len(y_train):
        raise ValueError(
            f"image count {num_images} in {data_name!r} does not match "
            f"label count {len(y_train)} in {label_name!r}"
        )
    # Image dimensions are taken from the header rather than assumed 28x28.
    x_train = np.frombuffer(image_bytes, np.uint8, offset=16).reshape(
        len(y_train), rows, cols
    )
    return (x_train, y_train)