If you want to load your own data, you can work most of it out from MXNet's MNIST example.
Let's compare how PyTorch and MXNet load data.
PyTorch uses a function called find_class, while MXNet defines a _list_images function inside the Dataset class. Honestly, I'm not sure that function matters much by itself: all that really matters is that __getitem__ can index into a list whose elements are tuples of (file name, corresponding label).
You only need to subclass this one class (gluon's dataset.Dataset).
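To make the (file name, label) list concrete: each line of the label file is assumed to look like "<name> <label>", and the listing step turns it into a list of tuples. A minimal sketch, with file names and labels made up purely for illustration:

# Sketch of the items list the Dataset indexes into.
# File names and labels here are invented for illustration.
lines = ['img_0001 3.0', 'img_0002 7.0']      # each line: "<name> <label>"
items = []
for line in lines:
    parts = line.split()
    name = parts.pop(0) + '.jpg'              # name comes first, extension appended
    items.append((name, float(parts[0])))     # (file name, label) tuple
print(items)  # [('img_0001.jpg', 3.0), ('img_0002.jpg', 7.0)]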
Straight to the code.
This is a snippet from a Kaggle competition I entered. It doesn't converge, but please don't mind such details.
# -*-coding:utf-8-*-
from mxnet import autograd
from mxnet import gluon
from mxnet import image
from mxnet import init
from mxnet import nd
from mxnet.gluon.data import vision
import numpy as np
from mxnet.gluon.data import dataset
import os
import warnings
import random
import mxnet as mx  # needed by _get_batch below
from mxnet import gpu
from mxnet.gluon.data.vision import datasets


class MyImageFolderDataset(dataset.Dataset):
    def __init__(self, root, label, flag=1, transform=None):
        self._root = os.path.expanduser(root)
        self._flag = flag
        self._label = label
        self._transform = transform
        self._exts = ['.jpg', '.jpeg', '.png']
        self._list_images(self._root, self._label)

    def _list_images(self, root, label):  # label is a list of "name label" lines
        self.synsets = []
        self.synsets.append(root)
        self.items = []
        c = 0
        for line in label:
            cls = line.split()
            fn = cls.pop(0)
            fn = fn + '.jpg'
            if os.path.isfile(os.path.join(root, fn)):
                self.items.append((os.path.join(root, fn), float(cls[0])))
            else:
                print('missing file:', os.path.join(root, fn))
            c = c + 1
        print('the total number of images is', c)

    def __getitem__(self, idx):
        img = image.imread(self.items[idx][0], self._flag)
        label = self.items[idx][1]
        if self._transform is not None:
            return self._transform(img, label)
        return img, label

    def __len__(self):
        return len(self.items)


def _get_batch(batch, ctx):
    """Return data and label on ctx, so a loop can do `for data, label in ...`."""
    if isinstance(batch, mx.io.DataBatch):
        data = batch.data[0]
        label = batch.label[0]
    else:
        data, label = batch
    return (gluon.utils.split_and_load(data, ctx),
            gluon.utils.split_and_load(label, ctx),
            data.shape[0])


def transform_train(data, label):
    im = image.imresize(data.astype('float32') / 255, 256, 256)
    auglist = image.CreateAugmenter(data_shape=(3, 256, 256), resize=0,
                                    rand_crop=False, rand_resize=False, rand_mirror=True,
                                    mean=None, std=None,
                                    brightness=0, contrast=0,
                                    saturation=0, hue=0,
                                    pca_noise=0, rand_gray=0, inter_method=2)
    for aug in auglist:
        im = aug(im)
    # change the layout from height x width x channel to channel x height x width
    im = nd.transpose(im, (2, 0, 1))
    return (im, nd.array([label]).asscalar().astype('float32'))


def transform_test(data, label):
    im = image.imresize(data.astype('float32') / 255, 256, 256)
    im = nd.transpose(im, (2, 0, 1))  # this transpose was missing in an earlier run
    return (im, nd.array([label]).asscalar().astype('float32'))


batch_size = 16
root = '/home/ying/data2/shiyongjie/landmark_recognition/data/image'


def random_choose_data(label_path):
    f = open(label_path)
    lines = f.readlines()  # was misspelled readlins()
    random.shuffle(lines)
    total_number = len(lines)
    train_number = total_number * 7 // 10  # 70/30 split; integer division so it can slice
    train_list = lines[:train_number]
    test_list = lines[train_number:]
    return (train_list, test_list)


label_path = '/home/ying/data2/shiyongjie/landmark_recognition/data/train.txt'
train_list, test_list = random_choose_data(label_path)
loader = gluon.data.DataLoader
train_ds = MyImageFolderDataset(os.path.join(root, 'image'), train_list, flag=1, transform=transform_train)
test_ds = MyImageFolderDataset(os.path.join(root, 'Testing'), test_list, flag=1, transform=transform_test)
train_data = loader(train_ds, batch_size, shuffle=True, last_batch='keep')
test_data = loader(test_ds, batch_size, shuffle=False, last_batch='keep')
softmax_cross_entropy = gluon.loss.L2Loss()  # despite the name, this defines an L2 loss


from mxnet.gluon import nn

net = nn.Sequential()
with net.name_scope():
    net.add(
        # stage 1
        nn.Conv2D(channels=96, kernel_size=11,
                  strides=4, activation='relu'),
        nn.MaxPool2D(pool_size=3, strides=2),
        # stage 2
        nn.Conv2D(channels=256, kernel_size=5,
                  padding=2, activation='relu'),
        nn.MaxPool2D(pool_size=3, strides=2),
        # stage 3
        nn.Conv2D(channels=384, kernel_size=3,
                  padding=1, activation='relu'),
        nn.Conv2D(channels=384, kernel_size=3,
                  padding=1, activation='relu'),
        nn.Conv2D(channels=256, kernel_size=3,
                  padding=1, activation='relu'),
        nn.MaxPool2D(pool_size=3, strides=2),
        # stage 4
        nn.Flatten(),
        nn.Dense(4096, activation="relu"),
        nn.Dropout(.5),
        # stage 5
        nn.Dense(4096, activation="relu"),
        nn.Dropout(.5),
        # stage 6
        nn.Dense(14950)  # output layer
    )

from mxnet import init
from mxnet import gluon
import mxnet as mx
import utils
import datetime
from time import time

ctx = utils.try_gpu()
net.initialize(ctx=ctx, init=init.Xavier())

mse_loss = gluon.loss.L2Loss()

# utils.train(train_data, test_data, net, loss,
#             trainer, ctx, num_epochs=10)
# train() below is the body of utils.train, unrolled so it can be tweaked:
def train(net, train_data, valid_data, num_epochs, lr, wd, ctx, lr_period, lr_decay):
    """Train a network."""
    print("Start training on ", ctx)
    trainer = gluon.Trainer(net.collect_params(), 'sgd',
                            {'learning_rate': lr, 'momentum': 0.9, 'wd': wd})
    prev_time = datetime.datetime.now()
    for epoch in range(num_epochs):
        train_loss = 0.0
        if epoch > 0 and epoch % lr_period == 0:
            trainer.set_learning_rate(trainer.learning_rate * lr_decay)
        for data, label in train_data:
            label = label.as_in_context(ctx)
            with autograd.record():
                output = net(data.as_in_context(ctx))
                loss = mse_loss(output, label)
            loss.backward()
            # do the update; Trainer needs the batch size to normalize the gradient by 1/batch_size
            trainer.step(batch_size)
            train_loss += nd.mean(loss).asscalar()
            print(nd.mean(loss).asscalar())
        cur_time = datetime.datetime.now()
        h, remainder = divmod((cur_time - prev_time).seconds, 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)
        epoch_str = ('Epoch %d. Train loss: %f, ' % (epoch, train_loss / len(train_data)))
        prev_time = cur_time
        print(epoch_str + time_str + ', lr ' + str(trainer.learning_rate))
    net.collect_params().save('./model/alexnet.params')


ctx = utils.try_gpu()
num_epochs = 100
learning_rate = 0.001
weight_decay = 5e-4
lr_period = 10
lr_decay = 0.1

train(net, train_data, test_data, num_epochs, learning_rate,
      weight_decay, ctx, lr_period, lr_decay)
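Note that _get_batch is defined above but never actually called in this training loop. In case it helps, here is a sketch (mine, not part of the original script) of how it could drive a multi-device loop, assuming ctx is a list of contexts such as [mx.gpu(0), mx.gpu(1)]:

# Hypothetical multi-GPU variant of the inner loop, built on _get_batch.
for batch in train_data:
    data_parts, label_parts, n = _get_batch(batch, [mx.gpu(0), mx.gpu(1)])
    with autograd.record():
        losses = [mse_loss(net(X), y)
                  for X, y in zip(data_parts, label_parts)]
    for l in losses:
        l.backward()
    trainer.step(n)  # n is the total batch size across devices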
Take a look at this part:
class MyImageFolderDataset(dataset.Dataset):
    def __init__(self, root, label, flag=1, transform=None):
        self._root = os.path.expanduser(root)
        self._flag = flag
        self._label = label
        self._transform = transform
        self._exts = ['.jpg', '.jpeg', '.png']
        self._list_images(self._root, self._label)

    def _list_images(self, root, label):  # label is a list of "name label" lines
        self.synsets = []
        self.synsets.append(root)
        self.items = []
        c = 0
        for line in label:
            cls = line.split()
            fn = cls.pop(0)
            fn = fn + '.jpg'
            if os.path.isfile(os.path.join(root, fn)):
                self.items.append((os.path.join(root, fn), float(cls[0])))
            else:
                print('missing file:', os.path.join(root, fn))
            c = c + 1
        print('the total number of images is', c)

    def __getitem__(self, idx):
        img = image.imread(self.items[idx][0], self._flag)
        label = self.items[idx][1]
        if self._transform is not None:
            return self._transform(img, label)
        return img, label

    def __len__(self):
        return len(self.items)


batch_size = 16
root = '/home/ying/data2/shiyongjie/landmark_recognition/data/image'


def random_choose_data(label_path):
    f = open(label_path)
    lines = f.readlines()
    random.shuffle(lines)
    total_number = len(lines)
    train_number = total_number * 7 // 10  # 70/30 train/test split
    train_list = lines[:train_number]
    test_list = lines[train_number:]
    return (train_list, test_list)


label_path = '/home/ying/data2/shiyongjie/landmark_recognition/data/train.txt'
train_list, test_list = random_choose_data(label_path)

loader = gluon.data.DataLoader
train_ds = MyImageFolderDataset(os.path.join(root, 'image'), train_list, flag=1, transform=transform_train)
test_ds = MyImageFolderDataset(os.path.join(root, 'Testing'), test_list, flag=1, transform=transform_test)
train_data = loader(train_ds, batch_size, shuffle=True, last_batch='keep')
test_data = loader(test_ds, batch_size, shuffle=False, last_batch='keep')
MyImageFolderDataset is a subclass of dataset.Dataset. The main job is to overload the indexing operator __getitem__ so that it returns an image and its corresponding label; the _list_images function before it just has to fill in the items list. (Operator overloading itself is a topic I'll leave for a future post.)
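Since operator overloading came up: defining __getitem__ and __len__ is exactly what makes ds[i] and len(ds) work, which is all DataLoader relies on. A plain-Python toy example, unrelated to MXNet:

# Toy class showing the indexing operator overload.
class Squares:
    def __init__(self, n):
        self._n = n

    def __getitem__(self, idx):   # squares[idx] calls this
        return idx * idx

    def __len__(self):            # len(squares) calls this
        return self._n

squares = Squares(5)
print(squares[3])    # 9
print(len(squares))  # 5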
You could say it is very similar to PyTorch. Even Mu Li (沐神) mentioned in his lectures that a lot of PyTorch's design was borrowed when MXNet's Gluon API was written.
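For comparison, the PyTorch version of the same idea looks almost identical. A minimal sketch (the items list here is made up, not the data from above):

# Minimal PyTorch counterpart, for comparison; items is hypothetical.
import torch
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    def __init__(self, items):
        self.items = items            # list of (data, label) tuples

    def __getitem__(self, idx):       # same indexing protocol as in Gluon
        return self.items[idx]

    def __len__(self):
        return len(self.items)

ds = MyDataset([(torch.zeros(3, 256, 256), 0.0)])
dl = DataLoader(ds, batch_size=16, shuffle=True)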