mxnet自定义dataloader加载自己的数据

实际上关于pytorch加载自己的数据之前有写过一篇博客,但是最近接触了mxnet,发现关于这方面的教程很少

如果要加载自己定义的数据的话,看mxnet关于mnist的例子,基本上能够推测出一二

看pytorch与mxnet他们加载数据方式的对比

上图左边是pytorch的,右边是mxnet的

实际上,mxnet与pytorch的datalayer有着相似之处,为什么这样说呢?直接看上面的代码,基本上都是输入图像的路径,然后输出一个可以供loader调用的可迭代对象,所以无论是pytorch还是mxnet,如果要加载自己的数据,只需要参照ImageFolderDataset这个类,在处理自己数据的那一部分进行修改就行,也就是直接继承dataset.Dataset类即可

对于pytorch而言,它使用了find_classes这样一个函数,而对于mxnet而言,实际上它在类内部定义了一个_list_images的函数,事实上我并没有发现这有没有用,只需要让self.items成为一个list,list中的每个元素是一个tuple,一个是文件的名字,另外一个是文件所对应的label,然后在__getitem__这个函数中按索引读取并返回图像及其对应的label即可。

只需要继承这一个类即可

直接撸代码

这个是我参加kaggle比赛的一段代码,尽管并不收敛,但请不要在意这些细节

  1 # -*-coding:utf-8-*-
  2 from mxnet import autograd
  3 from mxnet import gluon
  4 from mxnet import image
  5 from mxnet import init
  6 from mxnet import nd
  7 from mxnet.gluon.data import vision
  8 import numpy as np
  9 from mxnet.gluon.data import dataset
 10 import os
 11 import warnings
 12 import random
 13 from mxnet import gpu
 14 from mxnet.gluon.data.vision import datasets
 15 
 16 class MyImageFolderDataset(dataset.Dataset):
 17     def __init__(self, root, label, flag=1, transform=None):
 18         self._root = os.path.expanduser(root)
 19         self._flag = flag
 20         self._label = label
 21         self._transform = transform
 22         self._exts = ['.jpg', '.jpeg', '.png']
 23         self._list_images(self._root, self._label)
 24 
 25     def _list_images(self, root, label):  # label是一个list
 26         self.synsets = []
 27         self.synsets.append(root)
 28         self.items = []
 29         #file = open(label)
 30         #lines = file.readlines()
 31         #random.shuffle(lines)
 32         c = 0
 33         for line in label:
 34             cls = line.split()
 35             fn = cls.pop(0)
 36             fn = fn + '.jpg'
 37             # print(os.path.join(root, fn))
 38             if os.path.isfile(os.path.join(root, fn)):
 39                 self.items.append((os.path.join(root, fn), float(cls[0])))
 40                 # print((os.path.join(root, fn), float(cls[0])))
 41             else:
 42                 print('what')
 43             c = c + 1
 44         print('the total image is ', c)
 45 
 46     def __getitem__(self, idx):
 47         img = image.imread(self.items[idx][0], self._flag)
 48         label = self.items[idx][1]
 49         if self._transform is not None:
 50             return self._transform(img, label)
 51         return img, label
 52 
 53     def __len__(self):
 54         return len(self.items)
 55 
 56 
 57 def _get_batch(batch, ctx):  # 可以在循环中直接for i, data, label,函数主要把data放在ctx上
 58     """return data and label on ctx"""
 59     if isinstance(batch, mx.io.DataBatch):
 60         data = batch.data[0]
 61         label = batch.label[0]
 62     else:
 63         data, label = batch
 64     return (gluon.utils.split_and_load(data, ctx),
 65             gluon.utils.split_and_load(label, ctx),
 66             data.shape[0])
 67 
 68 def transform_train(data, label):
 69     im = image.imresize(data.astype('float32') / 255, 256, 256)
 70     auglist = image.CreateAugmenter(data_shape=(3, 256, 256), resize=0,
 71                         rand_crop=False, rand_resize=False, rand_mirror=True,
 72                         mean=None, std=None,
 73                         brightness=0, contrast=0,
 74                         saturation=0, hue=0,
 75                         pca_noise=0, rand_gray=0, inter_method=2)
 76     for aug in auglist:
 77         im = aug(im)
 78     # 将数据格式从"高*宽*通道"改为"通道*高*宽"。
 79     im = nd.transpose(im, (2, 0, 1))
 80     return (im, nd.array([label]).asscalar().astype('float32'))
 81 
 82 
 83 def transform_test(data, label):
 84     im = image.imresize(data.astype('float32') / 255, 256, 256)
 85     im = nd.transpose(im, (2, 0, 1))  # 之前没有运行此变换
 86     return (im, nd.array([label]).asscalar().astype('float32'))
 87 
 88 batch_size = 16
 89 root = '/home/ying/data2/shiyongjie/landmark_recognition/data/image'
 90 def random_choose_data(label_path):
 91     f = open(label_path)
 92     lines = f.readlins()
 93     random.shuffle(lines)
 94     total_number = len(lines)
 95     train_number = total_number/10*7
 96     train_list = lines[:train_number]
 97     test_list = lines[train_number:]
 98     return (train_list, test_list)
 99 
# Label file that is shuffled and split into a training and a test list.
label_path = '/home/ying/data2/shiyongjie/landmark_recognition/data/train.txt'
train_list, test_list = random_choose_data(label_path)

loader = gluon.data.DataLoader

# NOTE(review): ``root`` already ends in '/image', so the training directory
# resolves to '.../data/image/image' -- confirm this is the intended layout.
train_ds = MyImageFolderDataset(
    os.path.join(root, 'image'), train_list,
    flag=1, transform=transform_train)
test_ds = MyImageFolderDataset(
    os.path.join(root, 'Testing'), test_list,
    flag=1, transform=transform_test)

# last_batch='keep' means the final batch may hold fewer than batch_size
# samples.
train_data = loader(train_ds, batch_size, shuffle=True, last_batch='keep')
test_data = loader(test_ds, batch_size, shuffle=False, last_batch='keep')

# Despite the variable name, this is an L2 loss.
softmax_cross_entropy = gluon.loss.L2Loss()
108 
109 
110 from mxnet.gluon import nn
111 
# AlexNet-style network: five convolutions with three max-pool layers,
# followed by three dense layers.
net = nn.Sequential()
with net.name_scope():
    # Stage 1
    net.add(nn.Conv2D(channels=96, kernel_size=11,
                      strides=4, activation='relu'))
    net.add(nn.MaxPool2D(pool_size=3, strides=2))
    # Stage 2
    net.add(nn.Conv2D(channels=256, kernel_size=5,
                      padding=2, activation='relu'))
    net.add(nn.MaxPool2D(pool_size=3, strides=2))
    # Stage 3
    net.add(nn.Conv2D(channels=384, kernel_size=3,
                      padding=1, activation='relu'))
    net.add(nn.Conv2D(channels=384, kernel_size=3,
                      padding=1, activation='relu'))
    net.add(nn.Conv2D(channels=256, kernel_size=3,
                      padding=1, activation='relu'))
    net.add(nn.MaxPool2D(pool_size=3, strides=2))
    # Stage 4
    net.add(nn.Flatten())
    net.add(nn.Dense(4096, activation="relu"))
    net.add(nn.Dropout(.5))
    # Stage 5
    net.add(nn.Dense(4096, activation="relu"))
    net.add(nn.Dropout(.5))
    # Stage 6
    # NOTE(review): the original comment said "output is 1 value", but this
    # layer has 14950 units -- confirm which one is intended.
    net.add(nn.Dense(14950))
141 
142 from mxnet import init
143 from mxnet import gluon
144 import mxnet as mx
145 import utils
146 import datetime
147 from time import time
148 
# ``utils`` is a project-local helper module; presumably try_gpu() returns a
# GPU context when one is usable -- verify against utils.py.
ctx = utils.try_gpu()
net.initialize(ctx=ctx, init=init.Xavier())

# Loss used by train() below (read there as a module-level global).
mse_loss = gluon.loss.L2Loss()

# An earlier version delegated to utils.train(train_data, test_data, net,
# loss, trainer, ctx, num_epochs=10); training now uses train() below.
num_epochs = 10
print_batches = 100

# Train a network
print("Start training on ", ctx)
# Wrap a single context in a list.
ctx = [ctx] if isinstance(ctx, mx.Context) else ctx
def train(net, train_data, valid_data, num_epochs, lr, wd, ctx, lr_period, lr_decay):
    """Train ``net`` with SGD (momentum 0.9) and save its parameters.

    Parameters
    ----------
    net : gluon block
        Network to train; its parameters are updated in place.
    train_data : iterable of (data, label) batches
        Training batches; ``len(train_data)`` is used to average the loss.
    valid_data
        Accepted for interface compatibility; currently unused.
    num_epochs : int
        Number of passes over ``train_data``.
    lr : float
        Initial learning rate.
    wd : float
        Weight decay.
    ctx
        Context that every batch is copied to via ``as_in_context``.
    lr_period : int
        The learning rate is multiplied by ``lr_decay`` every ``lr_period``
        epochs (never at epoch 0).
    lr_decay : float
        Multiplicative learning-rate decay factor.

    Uses the module-level ``mse_loss`` as the loss function.  After the last
    epoch the parameters are written to ``./model/alexnet.params``.
    """
    trainer = gluon.Trainer(net.collect_params(), 'sgd',
                            {'learning_rate': lr, 'momentum': 0.9, 'wd': wd})
    prev_time = datetime.datetime.now()
    for epoch in range(num_epochs):
        train_loss = 0.0
        if epoch > 0 and epoch % lr_period == 0:
            trainer.set_learning_rate(trainer.learning_rate * lr_decay)
        for data, label in train_data:
            label = label.as_in_context(ctx)
            with autograd.record():
                output = net(data.as_in_context(ctx))
                loss = mse_loss(output, label)
            loss.backward()
            # Normalize the gradient by the size of *this* batch: the last
            # batch of an epoch may be smaller than the nominal batch size.
            trainer.step(data.shape[0])
            # Fetch the scalar once; every asscalar() call forces a sync.
            batch_loss = nd.mean(loss).asscalar()
            train_loss += batch_loss
            print(batch_loss)
        cur_time = datetime.datetime.now()
        # total_seconds() keeps whole days that ``.seconds`` would drop.
        elapsed = int((cur_time - prev_time).total_seconds())
        h, remainder = divmod(elapsed, 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)
        epoch_str = ('Epoch %d. Train loss: %f, ' % (epoch, train_loss / len(train_data)))
        prev_time = cur_time
        print(epoch_str + time_str + ', lr' + str(trainer.learning_rate))
    # Create the output directory first so a finished training run is not
    # lost to a missing folder.
    params_path = './model/alexnet.params'
    model_dir = os.path.dirname(params_path)
    if not os.path.isdir(model_dir):
        os.makedirs(model_dir)
    net.collect_params().save(params_path)
# Fetch the context again: ``ctx`` was wrapped in a list above, while
# train() passes it to as_in_context().
ctx = utils.try_gpu()

# Hyper-parameters of the training run.
num_epochs = 100
learning_rate = 0.001
weight_decay = 5e-4
lr_period = 10   # decay the learning rate every 10 epochs ...
lr_decay = 0.1   # ... by this factor

train(net, train_data, test_data,
      num_epochs, learning_rate, weight_decay,
      ctx, lr_period, lr_decay)
View Code

 

请看这一段

 1 class MyImageFolderDataset(dataset.Dataset):
 2     def __init__(self, root, label, flag=1, transform=None):
 3         self._root = os.path.expanduser(root)
 4         self._flag = flag
 5         self._label = label
 6         self._transform = transform
 7         self._exts = ['.jpg', '.jpeg', '.png']
 8         self._list_images(self._root, self._label)
 9 
10     def _list_images(self, root, label):  # label是一个list
11         self.synsets = []
12         self.synsets.append(root)
13         self.items = []
14         #file = open(label)
15         #lines = file.readlines()
16         #random.shuffle(lines)
17         c = 0
18         for line in label:
19             cls = line.split()
20             fn = cls.pop(0)
21             fn = fn + '.jpg'
22             # print(os.path.join(root, fn))
23             if os.path.isfile(os.path.join(root, fn)):
24                 self.items.append((os.path.join(root, fn), float(cls[0])))
25                 # print((os.path.join(root, fn), float(cls[0])))
26             else:
27                 print('what')
28             c = c + 1
29         print('the total image is ', c)
30 
31     def __getitem__(self, idx):
32         img = image.imread(self.items[idx][0], self._flag)
33         label = self.items[idx][1]
34         if self._transform is not None:
35             return self._transform(img, label)
36         return img, label
37 
38     def __len__(self):
39         return len(self.items)
# Number of samples per mini-batch handed to the data loaders.
batch_size = 16
# Base directory of the image data on the author's machine.
root = '/home/ying/data2/shiyongjie/landmark_recognition/data/image'
def random_choose_data(label_path):
    """Shuffle the lines of a label file and split them 70/30.

    Parameters
    ----------
    label_path : str
        Path of a text file with one sample per line.

    Returns
    -------
    tuple of (list of str, list of str)
        ``(train_list, test_list)``: the first 70% (rounded down) of the
        shuffled lines and the remainder.  Lines keep their trailing
        newline, exactly as ``readlines`` returns them.
    """
    # ``with`` guarantees the file is closed; the method is ``readlines``.
    with open(label_path) as f:
        lines = f.readlines()
    random.shuffle(lines)
    # Integer arithmetic only: a float cannot be used as a slice index,
    # and multiplying before dividing avoids losing up to 9 lines.
    train_number = len(lines) * 7 // 10
    train_list = lines[:train_number]
    test_list = lines[train_number:]
    return (train_list, test_list)
51 
# Label file that is shuffled and split into a training and a test list.
label_path = '/home/ying/data2/shiyongjie/landmark_recognition/data/train.txt'
train_list, test_list = random_choose_data(label_path)

loader = gluon.data.DataLoader

# NOTE(review): ``root`` already ends in '/image', so the training directory
# resolves to '.../data/image/image' -- confirm this is the intended layout.
train_ds = MyImageFolderDataset(
    os.path.join(root, 'image'), train_list,
    flag=1, transform=transform_train)
test_ds = MyImageFolderDataset(
    os.path.join(root, 'Testing'), test_list,
    flag=1, transform=transform_test)

# last_batch='keep' means the final batch may hold fewer than batch_size
# samples.
train_data = loader(train_ds, batch_size, shuffle=True, last_batch='keep')
test_data = loader(test_ds, batch_size, shuffle=False, last_batch='keep')
View Code

MyImageFolderDataset是dataset.Dataset的子类,主要是重载索引运算__getitem__,并且返回image以及其对应的label即可,前面的_list_images函数只要能够构建出items这个list就行,关于运算符重载给自己挖个坑

可以说和pytorch非常像了,就连沐神在讲课的时候还在说,其实在写mxnet的时候,借鉴了很多pytorch的内容

 

你可能感兴趣的:(mxnet自定义dataloader加载自己的数据)