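CIFAR-10 数据集的下载地址为：http://www.cs.toronto.edu/~kriz/cifar.html
请下载其中的 “CIFAR-10 python version”，并把解压得到的 data_batch_1 … data_batch_5 和 test_batch 文件放到脚本同级的 cifar10/ 目录下（下面的代码按 cifar10/data_batch_1 等路径读取）。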
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
import numpy as np
def unpickle(file):
    """Load one pickled CIFAR-10 batch file and return the unpickled object.

    ``encoding='latin1'`` lets Python 3 read pickles written by Python 2
    (the ``pickle`` docs recommend it for NumPy arrays pickled under
    Python 2); the CIFAR-10 "python version" batches are presumably such
    pickles.

    Args:
        file: Path of the batch file to read.

    Returns:
        The unpickled object.  The rest of this script indexes it as a dict
        with ``'data'`` and ``'labels'`` keys.

    Warning:
        ``pickle`` can execute arbitrary code while loading; only call this
        on files obtained from a trusted source.
    """
    import pickle
    with open(file, 'rb') as fo:
        # Named ``batch`` rather than ``dict`` to avoid shadowing the builtin.
        batch = pickle.load(fo, encoding='latin1')
    return batch
def onehot(labels, n_class=None):
    """Convert a sequence of integer class labels to a one-hot matrix.

    Args:
        labels: 1-D sequence (list or array) of non-negative integer labels.
        n_class: Number of columns of the result.  Defaults to
            ``max(labels) + 1``, which under-counts the classes whenever the
            largest label is absent from ``labels``; pass it explicitly
            (e.g. 10 for CIFAR-10) to get a fixed width.

    Returns:
        A float64 array of shape ``(len(labels), n_class)`` holding a single
        1.0 per row, in the column given by that row's label.
    """
    labels = np.asarray(labels, dtype=np.int64)
    n_sample = len(labels)
    if n_class is None:
        # An empty label list has no maximum; give it zero columns instead of raising.
        n_class = int(labels.max()) + 1 if n_sample else 0
    onehot_labels = np.zeros((n_sample, n_class))
    # Fancy indexing: row r gets a 1 in column labels[r].
    onehot_labels[np.arange(n_sample), labels] = 1
    return onehot_labels
# Load the five training batches (cifar10/data_batch_1 ... data_batch_5) and
# stack them, in order, into a single training set.
train_batches = [unpickle('cifar10/data_batch_%d' % i) for i in range(1, 6)]
X_train = np.concatenate([b['data'] for b in train_batches], axis=0)
Y_train = np.concatenate([b['labels'] for b in train_batches], axis=0)
# One-hot encode; the width is taken from the largest label present.
Y_train = onehot(Y_train)
# Only the first 5000 samples of the test batch are kept for evaluation.
test = unpickle('cifar10/test_batch')
X_test = test['data'][:5000, :]
Y_test = onehot(test['labels'])[:5000, :]
# NOTE(review): pixel values are fed to the network unscaled; consider
# normalizing (e.g. dividing by 255) before training.
print(X_train.shape)
print(Y_train.shape)
print(X_test.shape)
print(Y_test.shape)
# Inputs: flattened 3072-value image rows and one-hot labels over 10 classes.
xs = tf.placeholder(tf.float32, [None, 32*32*3])
ys = tf.placeholder(tf.float32, [None, 10])
# CIFAR-10 rows are channel-planar: 1024 red values, then 1024 green, then
# 1024 blue, each plane in row-major order.  Reshape to NCHW and transpose to
# NHWC; reshaping straight to [-1, 32, 32, 3] would interleave the planes and
# scramble the image.
x_image = tf.transpose(tf.reshape(xs, [-1, 3, 32, 32]), [0, 2, 3, 1])
# conv(5x5, 32) -> avg-pool(3x3, stride 2) -> LRN: 32x32 -> 16x16
conv1 = tf.layers.conv2d(x_image, 32, 5, 1, 'same', activation=tf.nn.relu)
pool1 = tf.layers.average_pooling2d(conv1, 3, 2, padding='same')
norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001/9.0, beta=0.75)
# conv(5x5, 64) -> LRN -> avg-pool(3x3, stride 2): 16x16 -> 8x8
conv2 = tf.layers.conv2d(norm1, 64, 5, 1, 'same', activation=tf.nn.relu)
norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001/9.0, beta=0.75)
pool2 = tf.layers.average_pooling2d(norm2, 3, 2, padding='same')
# Two stride-2 'same' poolings leave an 8x8 map with 64 channels.
pool2_flat = tf.reshape(pool2, [-1, 8*8*64])
fc1 = tf.layers.dense(pool2_flat, 384, activation=tf.nn.relu)
fc2 = tf.layers.dense(fc1, 192, activation=tf.nn.relu)
# The last layer must emit raw logits: softmax_cross_entropy applies softmax
# itself, so a softmax activation here would apply it twice, flattening the
# predicted distribution and weakening the gradients.
logits = tf.layers.dense(fc2, 10)
output = tf.nn.softmax(logits)
loss = tf.losses.softmax_cross_entropy(onehot_labels=ys, logits=logits)
train_step = tf.train.GradientDescentOptimizer(learning_rate=1e-3).minimize(loss)
# tf.argmax replaces the deprecated tf.arg_max.
correct_prediction = tf.equal(tf.argmax(output, 1), tf.argmax(ys, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
batch_size = 50
n_epochs = 400
# Chunk size for evaluation, so the whole test set is never pushed through
# the network in one run.
eval_batch_size = 500
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    # Integer division: a trailing partial batch, if any, is dropped.
    total_batch = X_train.shape[0] // batch_size
    for i in range(n_epochs):
        # Reshuffle every epoch so mini-batch composition differs between epochs.
        perm = np.random.permutation(X_train.shape[0])
        for batch in range(total_batch):
            idx = perm[batch * batch_size:(batch + 1) * batch_size]
            sess.run(train_step, feed_dict={xs: X_train[idx], ys: Y_train[idx]})
        # Report accuracy on the held-out X_test/Y_test instead of on the last
        # 50-sample training batch, which is noisy and measures fit rather than
        # generalization.  Chunk accuracies are weighted by chunk size.
        n_correct = 0.0
        for start in range(0, X_test.shape[0], eval_batch_size):
            bx = X_test[start:start + eval_batch_size]
            by = Y_test[start:start + eval_batch_size]
            n_correct += sess.run(accuracy, feed_dict={xs: bx, ys: by}) * bx.shape[0]
        acc = n_correct / X_test.shape[0]
        print(i, acc)
运行 400 个 epoch 的输出（原始代码；每行依次为 epoch 序号、该 epoch 最后一个训练 batch（50 个样本）上的准确率，并非测试集准确率）：
0 0.3
1 0.38
2 0.4
3 0.4
4 0.44
5 0.44
6 0.42
7 0.44
8 0.46
9 0.46
10 0.48
11 0.5
12 0.5
13 0.52
14 0.52
15 0.52
16 0.52
17 0.52
18 0.54
19 0.54
20 0.52
21 0.56
22 0.56
23 0.56
24 0.56
25 0.56
26 0.56
27 0.56
28 0.56
29 0.56
30 0.58
31 0.56
32 0.56
33 0.56
34 0.56
35 0.56
36 0.58
37 0.56
38 0.54
39 0.56
40 0.58
41 0.62
42 0.6
43 0.58
44 0.58
45 0.62
46 0.58
47 0.58
48 0.58
49 0.6
50 0.58
51 0.64
52 0.58
53 0.66
54 0.6
55 0.68
56 0.66
57 0.68
58 0.68
59 0.7
60 0.7
61 0.62
62 0.62
63 0.66
64 0.64
65 0.66
66 0.66
67 0.68
68 0.68
69 0.68
70 0.68
71 0.7
72 0.7
73 0.68
74 0.7
75 0.7
76 0.7
77 0.7
78 0.68
79 0.7
80 0.68
81 0.7
82 0.68
83 0.68
84 0.68
85 0.7
86 0.7
87 0.7
88 0.7
89 0.7
90 0.7
91 0.7
92 0.7
93 0.7
94 0.7
95 0.7
96 0.7
97 0.7
98 0.7
99 0.7
100 0.7
101 0.7
102 0.7
103 0.7
104 0.7
105 0.7
106 0.7
107 0.7
108 0.7
109 0.7
110 0.7
111 0.7
112 0.7
113 0.7
114 0.7
115 0.7
116 0.7
117 0.7
118 0.7
119 0.7
120 0.7
121 0.7
122 0.7
123 0.7
124 0.7
125 0.7
126 0.7
127 0.7
128 0.7
129 0.7
130 0.7
131 0.7
132 0.7
133 0.7
134 0.7
135 0.7
136 0.72
137 0.72
138 0.72
139 0.72
140 0.72
141 0.72
142 0.72
143 0.72
144 0.72
145 0.72
146 0.72
147 0.72
148 0.72
149 0.72
150 0.72
151 0.72
152 0.72
153 0.72
154 0.72
155 0.72
156 0.72
157 0.72
158 0.72
159 0.72
160 0.72
161 0.72
162 0.72
163 0.72
164 0.72
165 0.72
166 0.72
167 0.72
168 0.72
169 0.72
170 0.72
171 0.72
172 0.72
173 0.72
174 0.72
175 0.72
176 0.72
177 0.72
178 0.72
179 0.72
180 0.72
181 0.72
182 0.72
183 0.72
184 0.72
185 0.72
186 0.72
187 0.72
188 0.72
189 0.72
190 0.72
191 0.72
192 0.72
193 0.72
194 0.72
195 0.72
196 0.72
197 0.74
198 0.72
199 0.72
200 0.72
201 0.72
202 0.72
203 0.72
204 0.74
205 0.72
206 0.74
207 0.74
208 0.72
209 0.74
210 0.74
211 0.74
212 0.74
213 0.74
214 0.74
215 0.74
216 0.74
217 0.74
218 0.74
219 0.74
220 0.74
221 0.74
222 0.74
223 0.74
224 0.74
225 0.74
226 0.74
227 0.74
228 0.74
229 0.74
230 0.74
231 0.74
232 0.74
233 0.74
234 0.74
235 0.74
236 0.74
237 0.74
238 0.74
239 0.74
240 0.76
241 0.74
242 0.74
243 0.76
244 0.74
245 0.74
246 0.74
247 0.76
248 0.74
249 0.76
250 0.76
251 0.74
252 0.74
253 0.76
254 0.76
255 0.76
256 0.76
257 0.76
258 0.76
259 0.76
260 0.76
261 0.76
262 0.76
263 0.76
264 0.76
265 0.76
266 0.76
267 0.76
268 0.76
269 0.76
270 0.76
271 0.76
272 0.76
273 0.76
274 0.76
275 0.76
276 0.76
277 0.76
278 0.76
279 0.76
280 0.76
281 0.76
282 0.76
283 0.76
284 0.76
285 0.76
286 0.76
287 0.76
288 0.76
289 0.76
290 0.76
291 0.76
292 0.76
293 0.76
294 0.76
295 0.76
296 0.76
297 0.76
298 0.76
299 0.76
300 0.76
301 0.76
302 0.76
303 0.76
304 0.76
305 0.76
306 0.76
307 0.76
308 0.76
309 0.76
310 0.76
311 0.76
312 0.76
313 0.76
314 0.76
315 0.76
316 0.76
317 0.76
318 0.76
319 0.76
320 0.76
321 0.76
322 0.76
323 0.76
324 0.76
325 0.76
326 0.76
327 0.76
328 0.76
329 0.76
330 0.76
331 0.76
332 0.76
333 0.76
334 0.76
335 0.76
336 0.76
337 0.76
338 0.76
339 0.76
340 0.76
341 0.76
342 0.76
343 0.76
344 0.76
345 0.76
346 0.76
347 0.76
348 0.76
349 0.76
350 0.76
351 0.76
352 0.76
353 0.76
354 0.76
355 0.76
356 0.76
357 0.76
358 0.76
359 0.76
360 0.76
361 0.76
362 0.76
363 0.76
364 0.76
365 0.76
366 0.76
367 0.76
368 0.76
369 0.76
370 0.76
371 0.76
372 0.76
373 0.76
374 0.76
375 0.76
376 0.76
377 0.76
378 0.76
379 0.76
380 0.76
381 0.76
382 0.76
383 0.76
384 0.76
385 0.76
386 0.76
387 0.76
388 0.76
389 0.76
390 0.76
391 0.76
392 0.76
393 0.76
394 0.76
395 0.76
396 0.76
397 0.76
398 0.76
399 0.76