Classifying CIFAR-10 with a Capsule Network


Dataset: CIFAR-10 consists of 60,000 32x32 color images in 10 classes, with 6,000 images per class. There are 50,000 training images and 10,000 test images.


Experiment: building the capsule network
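Before the full script, here is a minimal NumPy sketch of two helpers it defines, squash and margin_loss, so their behavior can be checked on small inputs without a Keras backend. The NumPy port, the EPS constant standing in for K.epsilon(), the const parameter, and the sample vectors are all assumptions made for illustration; they are not part of the original script.

```python
import numpy as np

EPS = 1e-7  # stands in for K.epsilon(); the exact value is an assumption


def squash(x, const=0.5, axis=-1):
    """NumPy port of the squashing function: rescale x along `axis`."""
    sq_norm = np.sum(np.square(x), axis=axis, keepdims=True) + EPS
    return np.sqrt(sq_norm) / (const + sq_norm) * x


def margin_loss(y_true, y_pred, lamb=0.5, margin=0.1):
    """NumPy port of the margin loss, summed over the class axis."""
    present = y_true * np.square(np.maximum(0.0, 1 - margin - y_pred))
    absent = lamb * (1 - y_true) * np.square(np.maximum(0.0, y_pred - margin))
    return np.sum(present + absent, axis=-1)


# Vectors pointing in the same direction with norms 0.1, 0.5, 1 and 10.
norms = np.array([0.1, 0.5, 1.0, 10.0])
vectors = norms[:, None] * np.array([[0.6, 0.8]])
out_half = np.linalg.norm(squash(vectors, const=0.5), axis=-1)
out_one = np.linalg.norm(squash(vectors, const=1.0), axis=-1)
for n, a, b in zip(norms, out_half, out_one):
    print("norm %5.1f -> %.4f (const 0.5), %.4f (const 1)" % (n, a, b))

# One sample whose true class is 0, scored by two hypothetical predictions.
y_true = np.array([[1.0, 0.0, 0.0]])
confident = np.array([[0.95, 0.05, 0.05]])  # every class inside its margin
wrong = np.array([[0.05, 0.95, 0.05]])      # true class short, class 1 long
loss_confident = margin_loss(y_true, confident)
loss_wrong = margin_loss(y_true, wrong)
print(loss_confident, loss_wrong)
```

For an input of norm n the squashed norm is roughly n^2 / (const + n^2), so it stays below 1, and the constant 0.5 shrinks a vector less than the constant 1 does. The margin loss is zero once the true class is longer than 0.9 and every other class is shorter than 0.1.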

from __future__ import print_function
from keras import backend as K
from keras.layers import Layer
from keras import activations
from keras import utils
from keras.datasets import cifar10
from keras.models import Model
from keras.layers import *
from keras.preprocessing.image import ImageDataGenerator


# the squashing function.
# we use 0.5 instead of 1 as in Hinton's paper.
# for an input of norm n the output norm is n^2 / (c + n^2),
# which is smaller than both n and 1 for c = 0.5 as well as c = 1.
# with c = 0.5 the norm shrinks less than with c = 1,
# so short vectors are not suppressed as strongly.
def squash(x, axis=-1):
    s_squared_norm = K.sum(K.square(x), axis, keepdims=True) + K.epsilon()
    scale = K.sqrt(s_squared_norm) / (0.5 + s_squared_norm)
    return scale * x


# define our own softmax function instead of K.softmax
# because K.softmax cannot specify an axis.
def softmax(x, axis=-1):
    ex = K.exp(x - K.max(x, axis=axis, keepdims=True))
    return ex / K.sum(ex, axis=axis, keepdims=True)


# define the margin loss like hinge loss
def margin_loss(y_true, y_pred):
    lamb, margin = 0.5, 0.1
    return K.sum(y_true * K.square(K.relu(1 - margin - y_pred)) + lamb * (
        1 - y_true) * K.square(K.relu(y_pred - margin)), axis=-1)


class Capsule(Layer):
    """A Capsule implementation in pure Keras.
    There are two versions of Capsule.
    One is like a dense layer (for fixed-shape input),
    and the other is like a time-distributed dense layer (for variable-length input).

    The input shape of Capsule must be (batch_size,
                                        input_num_capsule,
                                        input_dim_capsule
                                       )
    and the output shape is (batch_size,
                             num_capsule,
                             dim_capsule
                            )

    Capsule Implement is from https://github.com/bojone/Capsule/
    Capsule Paper: https://arxiv.org/abs/1710.09829
    """

    def __init__(self,
                 num_capsule,
                 dim_capsule,
                 routings=3,
                 share_weights=True,
                 activation='squash',
                 **kwargs):
        super(Capsule, self).__init__(**kwargs)
        self.num_capsule = num_capsule
        self.dim_capsule = dim_capsule
        self.routings = routings
        self.share_weights = share_weights
        if activation == 'squash':
            self.activation = squash
        else:
            self.activation = activations.get(activation)

    def build(self, input_shape):
        input_dim_capsule = input_shape[-1]
        if self.share_weights:
            self.kernel = self.add_weight(
                name='capsule_kernel',
                shape=(1, input_dim_capsule,
                       self.num_capsule * self.dim_capsule),
                initializer='glorot_uniform',
                trainable=True)
        else:
            input_num_capsule = input_shape[-2]
            self.kernel = self.add_weight(
                name='capsule_kernel',
                shape=(input_num_capsule, input_dim_capsule,
                       self.num_capsule * self.dim_capsule),
                initializer='glorot_uniform',
                trainable=True)

    def call(self, inputs):
        """Following the routing algorithm from Hinton's paper,
        but replace b = b + <u,v> with b = <u,v>.

        This change can improve the feature representation of Capsule.

        However, you can replace
            b = K.batch_dot(o, hat_inputs, [2, 3])
        with
            b += K.batch_dot(o, hat_inputs, [2, 3])
        to realize a standard routing.
        """

        if self.share_weights:
            hat_inputs = K.conv1d(inputs, self.kernel)
        else:
            hat_inputs = K.local_conv1d(inputs, self.kernel, [1], [1])

        batch_size = K.shape(inputs)[0]
        input_num_capsule = K.shape(inputs)[1]
        hat_inputs = K.reshape(hat_inputs,
                               (batch_size, input_num_capsule,
                                self.num_capsule, self.dim_capsule))
        hat_inputs = K.permute_dimensions(hat_inputs, (0, 2, 1, 3))

        b = K.zeros_like(hat_inputs[:, :, :, 0])
        for i in range(self.routings):
            c = softmax(b, 1)
            o = self.activation(K.batch_dot(c, hat_inputs, [2, 2]))
            if i < self.routings - 1:
                b = K.batch_dot(o, hat_inputs, [2, 3])
                if K.backend() == 'theano':
                    o = K.sum(o, axis=1)

        return o

    def compute_output_shape(self, input_shape):
        return (None, self.num_capsule, self.dim_capsule)


batch_size = 128
num_classes = 10
epochs = 100
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
y_train = utils.to_categorical(y_train, num_classes)
y_test = utils.to_categorical(y_test, num_classes)

# A common Conv2D model
input_image = Input(shape=(None, None, 3))
x = Conv2D(64, (3, 3), activation='relu')(input_image)
x = Conv2D(64, (3, 3), activation='relu')(x)
x = AveragePooling2D((2, 2))(x)
x = Conv2D(128, (3, 3), activation='relu')(x)
x = Conv2D(128, (3, 3), activation='relu')(x)


"""Now we reshape it to (batch_size, input_num_capsule, input_dim_capsule)
and connect a Capsule layer.

The output of the final model is the lengths of 10 capsules, each with dim=16.

The length of a capsule is treated as a probability,
so the problem becomes 10 binary classification problems.
"""

x = Reshape((-1, 128))(x)
capsule = Capsule(10, 16, 3, True)(x)
output = Lambda(lambda z: K.sqrt(K.sum(K.square(z), 2)))(capsule)
model = Model(inputs=input_image, outputs=output)

# we use a margin loss
model.compile(loss=margin_loss, optimizer='adam', metrics=['accuracy'])
model.summary()

# we can compare the performance with or without data augmentation
data_augmentation = True

if not data_augmentation:
    print('Not using data augmentation.')
    model.fit(
        x_train,
        y_train,
        batch_size=batch_size,
        epochs=epochs,
        validation_data=(x_test, y_test),
        shuffle=True)
else:
    print('Using real-time data augmentation.')
    # This will do preprocessing and realtime data augmentation:
    datagen = ImageDataGenerator(
        featurewise_center=False,  # set input mean to 0 over the dataset
        samplewise_center=False,  # set each sample mean to 0
        featurewise_std_normalization=False,  # divide inputs by dataset std
        samplewise_std_normalization=False,  # divide each input by its std
        zca_whitening=False,  # apply ZCA whitening
        zca_epsilon=1e-06,  # epsilon for ZCA whitening
        rotation_range=0,  # randomly rotate images in 0 to 180 degrees
        width_shift_range=0.1,  # randomly shift images horizontally
        height_shift_range=0.1,  # randomly shift images vertically
        shear_range=0.,  # set range for random shear
        zoom_range=0.,  # set range for random zoom
        channel_shift_range=0.,  # set range for random channel shifts
        # set mode for filling points outside the input boundaries
        fill_mode='nearest',
        cval=0.,  # value used for fill_mode = "constant"
        horizontal_flip=True,  # randomly flip images
        vertical_flip=False,  # randomly flip images
        # set rescaling factor (applied before any other transformation)
        rescale=None,
        # set function that will be applied on each input
        preprocessing_function=None,
        # image data format, either "channels_first" or "channels_last"
        data_format=None,
        # fraction of images reserved for validation (strictly between 0 and 1)
        validation_split=0.0)

    # Compute quantities required for feature-wise normalization
    # (std, mean, and principal components if ZCA whitening is applied).
    datagen.fit(x_train)

    # Fit the model on the batches generated by datagen.flow().
    model.fit_generator(datagen.flow(x_train, y_train,
                                     batch_size=batch_size),
                        steps_per_epoch=len(x_train) // batch_size,
                        epochs=epochs,
                        validation_data=(x_test, y_test),
                        workers=4)
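
The forward pass of the Capsule layer with share_weights=True can be followed step by step in plain NumPy. The sketch below is an assumed re-implementation for illustration, not the Keras code itself: capsule_forward, the random inputs, and the random kernel are hypothetical, and np.einsum stands in for K.batch_dot. It uses 100 input capsules because a 32x32 image is reduced to a 10x10 feature map by the four unpadded 3x3 convolutions and the 2x2 pooling (32, 30, 28, 14, 12, 10).

```python
import numpy as np


def squash(x, axis=-1, eps=1e-7):
    """Squashing with the constant 0.5, as in the script."""
    sq_norm = np.sum(np.square(x), axis=axis, keepdims=True) + eps
    return np.sqrt(sq_norm) / (0.5 + sq_norm) * x


def softmax(x, axis=-1):
    ex = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return ex / np.sum(ex, axis=axis, keepdims=True)


def capsule_forward(inputs, kernel, num_capsule, dim_capsule, routings=3):
    """Shared-weight capsule layer (hypothetical NumPy port).

    inputs: (batch, input_num_capsule, input_dim_capsule)
    kernel: (input_dim_capsule, num_capsule * dim_capsule)
    returns: (batch, num_capsule, dim_capsule)
    """
    batch, input_num, _ = inputs.shape
    # One kernel shared by all input capsules, like a width-1 conv1d.
    hat = inputs @ kernel                          # (batch, input_num, num*dim)
    hat = hat.reshape(batch, input_num, num_capsule, dim_capsule)
    hat = hat.transpose(0, 2, 1, 3)                # (batch, num, input_num, dim)

    b = np.zeros(hat.shape[:3])                    # routing logits
    for i in range(routings):
        c = softmax(b, axis=1)                     # normalize over output capsules
        o = squash(np.einsum('bji,bjid->bjd', c, hat))
        if i < routings - 1:
            # b is overwritten with the agreement <o, hat>, not accumulated.
            b = np.einsum('bjd,bjid->bji', o, hat)
    return o


rng = np.random.default_rng(0)
inputs = rng.normal(size=(2, 100, 128))            # 2 images, 100 capsules, dim 128
kernel = rng.normal(scale=0.05, size=(128, 10 * 16))
caps = capsule_forward(inputs, kernel, num_capsule=10, dim_capsule=16)
lengths = np.sqrt(np.sum(np.square(caps), axis=2))  # what the Lambda layer computes
print(caps.shape, lengths.shape, kernel.size)
```

The kernel holds 128 * 10 * 16 = 20480 weights, which matches the parameter count that model.summary() reports for the Capsule layer, and every capsule length is below 1 because of the squashing.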

Results:

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170500096/170498071 [==============================] - 15s 0us/step
170508288/170498071 [==============================] - 15s 0us/step
WARNING:tensorflow:From /usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, None, None, 3)     0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, None, None, 64)    1792      
_________________________________________________________________
conv2d_2 (Conv2D)            (None, None, None, 64)    36928     
_________________________________________________________________
average_pooling2d_1 (Average (None, None, None, 64)    0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, None, None, 128)   73856     
_________________________________________________________________
conv2d_4 (Conv2D)            (None, None, None, 128)   147584    
_________________________________________________________________
reshape_1 (Reshape)          (None, None, 128)         0         
_________________________________________________________________
capsule_1 (Capsule)          (None, 10, 16)            20480     
_________________________________________________________________
lambda_1 (Lambda)            (None, 10)                0         
=================================================================
Total params: 280,640
Trainable params: 280,640
Non-trainable params: 0
_________________________________________________________________
Using real-time data augmentation.
WARNING:tensorflow:From /usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.cast instead.
WARNING:tensorflow:From /usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/math_grad.py:102: div (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Deprecated in favor of operator or tf.math.divide.
Epoch 1/100
390/390 [==============================] - 33s 83ms/step - loss: 0.4269 - acc: 0.3457 - val_loss: 0.3927 - val_acc: 0.4050
Epoch 2/100
390/390 [==============================] - 26s 66ms/step - loss: 0.3531 - acc: 0.4826 - val_loss: 0.3212 - val_acc: 0.5277
Epoch 3/100
390/390 [==============================] - 25s 63ms/step - loss: 0.3079 - acc: 0.5638 - val_loss: 0.2901 - val_acc: 0.5906
Epoch 4/100
390/390 [==============================] - 25s 64ms/step - loss: 0.2767 - acc: 0.6180 - val_loss: 0.2520 - val_acc: 0.6586
Epoch 5/100
390/390 [==============================] - 25s 64ms/step - loss: 0.2501 - acc: 0.6618 - val_loss: 0.2425 - val_acc: 0.6692
Epoch 6/100
390/390 [==============================] - 25s 65ms/step - loss: 0.2329 - acc: 0.6903 - val_loss: 0.2287 - val_acc: 0.6920
Epoch 7/100
390/390 [==============================] - 25s 65ms/step - loss: 0.2192 - acc: 0.7114 - val_loss: 0.2109 - val_acc: 0.7285
Epoch 8/100
390/390 [==============================] - 25s 63ms/step - loss: 0.2070 - acc: 0.7308 - val_loss: 0.2069 - val_acc: 0.7315
Epoch 9/100
390/390 [==============================] - 25s 64ms/step - loss: 0.1971 - acc: 0.7465 - val_loss: 0.1901 - val_acc: 0.7515
Epoch 10/100
390/390 [==============================] - 25s 63ms/step - loss: 0.1902 - acc: 0.7579 - val_loss: 0.2036 - val_acc: 0.7327
Epoch 11/100
390/390 [==============================] - 25s 64ms/step - loss: 0.1822 - acc: 0.7698 - val_loss: 0.1962 - val_acc: 0.7494
Epoch 12/100
390/390 [==============================] - 26s 66ms/step - loss: 0.1767 - acc: 0.7784 - val_loss: 0.1817 - val_acc: 0.7675
Epoch 13/100
390/390 [==============================] - 25s 65ms/step - loss: 0.1714 - acc: 0.7858 - val_loss: 0.1804 - val_acc: 0.7706
Epoch 14/100
390/390 [==============================] - 26s 66ms/step - loss: 0.1670 - acc: 0.7951 - val_loss: 0.1786 - val_acc: 0.7714
Epoch 15/100
390/390 [==============================] - 25s 65ms/step - loss: 0.1617 - acc: 0.8021 - val_loss: 0.1656 - val_acc: 0.7928
Epoch 16/100
390/390 [==============================] - 25s 63ms/step - loss: 0.1586 - acc: 0.8079 - val_loss: 0.1638 - val_acc: 0.7972
Epoch 17/100
390/390 [==============================] - 25s 65ms/step - loss: 0.1546 - acc: 0.8143 - val_loss: 0.1682 - val_acc: 0.7892
Epoch 18/100
390/390 [==============================] - 25s 64ms/step - loss: 0.1510 - acc: 0.8195 - val_loss: 0.1724 - val_acc: 0.7821
Epoch 19/100
390/390 [==============================] - 25s 63ms/step - loss: 0.1493 - acc: 0.8226 - val_loss: 0.1682 - val_acc: 0.7886
Epoch 20/100
390/390 [==============================] - 25s 65ms/step - loss: 0.1453 - acc: 0.8296 - val_loss: 0.1558 - val_acc: 0.8107
Epoch 21/100
390/390 [==============================] - 25s 64ms/step - loss: 0.1413 - acc: 0.8350 - val_loss: 0.1636 - val_acc: 0.7974
Epoch 22/100
390/390 [==============================] - 25s 64ms/step - loss: 0.1397 - acc: 0.8355 - val_loss: 0.1579 - val_acc: 0.8061
Epoch 23/100
390/390 [==============================] - 26s 66ms/step - loss: 0.1369 - acc: 0.8403 - val_loss: 0.1631 - val_acc: 0.8008
Epoch 24/100
390/390 [==============================] - 26s 66ms/step - loss: 0.1358 - acc: 0.8436 - val_loss: 0.1597 - val_acc: 0.8084
Epoch 25/100
390/390 [==============================] - 26s 66ms/step - loss: 0.1316 - acc: 0.8478 - val_loss: 0.1538 - val_acc: 0.8115
Epoch 26/100
390/390 [==============================] - 26s 67ms/step - loss: 0.1311 - acc: 0.8508 - val_loss: 0.1507 - val_acc: 0.8203
Epoch 27/100
390/390 [==============================] - 25s 64ms/step - loss: 0.1281 - acc: 0.8556 - val_loss: 0.1549 - val_acc: 0.8123
Epoch 28/100
390/390 [==============================] - 25s 65ms/step - loss: 0.1268 - acc: 0.8561 - val_loss: 0.1524 - val_acc: 0.8129
Epoch 29/100
390/390 [==============================] - 26s 65ms/step - loss: 0.1249 - acc: 0.8613 - val_loss: 0.1544 - val_acc: 0.8149
Epoch 30/100
390/390 [==============================] - 26s 66ms/step - loss: 0.1238 - acc: 0.8608 - val_loss: 0.1715 - val_acc: 0.7860
Epoch 31/100
390/390 [==============================] - 26s 66ms/step - loss: 0.1209 - acc: 0.8669 - val_loss: 0.1504 - val_acc: 0.8167
Epoch 32/100
390/390 [==============================] - 25s 65ms/step - loss: 0.1211 - acc: 0.8648 - val_loss: 0.1635 - val_acc: 0.7936
Epoch 33/100
390/390 [==============================] - 26s 66ms/step - loss: 0.1187 - acc: 0.8693 - val_loss: 0.1536 - val_acc: 0.8131
Epoch 34/100
390/390 [==============================] - 26s 66ms/step - loss: 0.1166 - acc: 0.8724 - val_loss: 0.1557 - val_acc: 0.8132
Epoch 35/100
390/390 [==============================] - 26s 66ms/step - loss: 0.1172 - acc: 0.8718 - val_loss: 0.1559 - val_acc: 0.8074
Epoch 36/100
390/390 [==============================] - 25s 65ms/step - loss: 0.1146 - acc: 0.8757 - val_loss: 0.1577 - val_acc: 0.8087
Epoch 37/100
390/390 [==============================] - 25s 63ms/step - loss: 0.1132 - acc: 0.8792 - val_loss: 0.1536 - val_acc: 0.8155
Epoch 38/100
390/390 [==============================] - 25s 65ms/step - loss: 0.1111 - acc: 0.8816 - val_loss: 0.1519 - val_acc: 0.8236
Epoch 39/100
390/390 [==============================] - 25s 65ms/step - loss: 0.1115 - acc: 0.8804 - val_loss: 0.1580 - val_acc: 0.8090
Epoch 40/100
390/390 [==============================] - 25s 65ms/step - loss: 0.1094 - acc: 0.8837 - val_loss: 0.1560 - val_acc: 0.8105
Epoch 41/100
390/390 [==============================] - 26s 66ms/step - loss: 0.1091 - acc: 0.8839 - val_loss: 0.1559 - val_acc: 0.8098
Epoch 42/100
390/390 [==============================] - 25s 65ms/step - loss: 0.1067 - acc: 0.8890 - val_loss: 0.1533 - val_acc: 0.8162
Epoch 43/100
390/390 [==============================] - 26s 66ms/step - loss: 0.1063 - acc: 0.8886 - val_loss: 0.1524 - val_acc: 0.8160
Epoch 44/100
390/390 [==============================] - 26s 66ms/step - loss: 0.1052 - acc: 0.8896 - val_loss: 0.1493 - val_acc: 0.8250
Epoch 45/100
390/390 [==============================] - 26s 66ms/step - loss: 0.1044 - acc: 0.8906 - val_loss: 0.1618 - val_acc: 0.8011
Epoch 46/100
390/390 [==============================] - 26s 66ms/step - loss: 0.1036 - acc: 0.8928 - val_loss: 0.1530 - val_acc: 0.8208
Epoch 47/100
390/390 [==============================] - 26s 65ms/step - loss: 0.1026 - acc: 0.8935 - val_loss: 0.1505 - val_acc: 0.8204
Epoch 48/100
390/390 [==============================] - 26s 66ms/step - loss: 0.1011 - acc: 0.8970 - val_loss: 0.1483 - val_acc: 0.8232
Epoch 49/100
390/390 [==============================] - 26s 66ms/step - loss: 0.1008 - acc: 0.8969 - val_loss: 0.1538 - val_acc: 0.8153
Epoch 50/100
390/390 [==============================] - 25s 65ms/step - loss: 0.0997 - acc: 0.8981 - val_loss: 0.1541 - val_acc: 0.8155
Epoch 51/100
390/390 [==============================] - 25s 64ms/step - loss: 0.0985 - acc: 0.8996 - val_loss: 0.1511 - val_acc: 0.8212
Epoch 52/100
390/390 [==============================] - 26s 66ms/step - loss: 0.0975 - acc: 0.9027 - val_loss: 0.1540 - val_acc: 0.8121
Epoch 53/100
390/390 [==============================] - 26s 65ms/step - loss: 0.0979 - acc: 0.9002 - val_loss: 0.1508 - val_acc: 0.8183
Epoch 54/100
390/390 [==============================] - 25s 64ms/step - loss: 0.0978 - acc: 0.9011 - val_loss: 0.1494 - val_acc: 0.8213
Epoch 55/100
390/390 [==============================] - 25s 65ms/step - loss: 0.0956 - acc: 0.9060 - val_loss: 0.1535 - val_acc: 0.8169
Epoch 56/100
390/390 [==============================] - 26s 66ms/step - loss: 0.0961 - acc: 0.9048 - val_loss: 0.1575 - val_acc: 0.8074
Epoch 57/100
390/390 [==============================] - 25s 65ms/step - loss: 0.0940 - acc: 0.9081 - val_loss: 0.1471 - val_acc: 0.8239
Epoch 58/100
390/390 [==============================] - 26s 66ms/step - loss: 0.0946 - acc: 0.9080 - val_loss: 0.1579 - val_acc: 0.8147
Epoch 59/100
390/390 [==============================] - 26s 67ms/step - loss: 0.0937 - acc: 0.9090 - val_loss: 0.1513 - val_acc: 0.8209
Epoch 60/100
390/390 [==============================] - 25s 65ms/step - loss: 0.0920 - acc: 0.9126 - val_loss: 0.1469 - val_acc: 0.8216
Epoch 61/100
390/390 [==============================] - 26s 66ms/step - loss: 0.0921 - acc: 0.9102 - val_loss: 0.1458 - val_acc: 0.8256
Epoch 62/100
390/390 [==============================] - 26s 66ms/step - loss: 0.0910 - acc: 0.9120 - val_loss: 0.1473 - val_acc: 0.8312
Epoch 63/100
390/390 [==============================] - 26s 66ms/step - loss: 0.0898 - acc: 0.9159 - val_loss: 0.1514 - val_acc: 0.8146
Epoch 64/100
390/390 [==============================] - 26s 66ms/step - loss: 0.0905 - acc: 0.9124 - val_loss: 0.1544 - val_acc: 0.8228
Epoch 65/100
390/390 [==============================] - 25s 65ms/step - loss: 0.0898 - acc: 0.9141 - val_loss: 0.1487 - val_acc: 0.8275
Epoch 66/100
390/390 [==============================] - 25s 65ms/step - loss: 0.0884 - acc: 0.9158 - val_loss: 0.1521 - val_acc: 0.8209
Epoch 67/100
390/390 [==============================] - 26s 66ms/step - loss: 0.0891 - acc: 0.9151 - val_loss: 0.1538 - val_acc: 0.8143
Epoch 68/100
390/390 [==============================] - 25s 65ms/step - loss: 0.0891 - acc: 0.9171 - val_loss: 0.1500 - val_acc: 0.8190
Epoch 69/100
390/390 [==============================] - 25s 65ms/step - loss: 0.0859 - acc: 0.9196 - val_loss: 0.1492 - val_acc: 0.8225
Epoch 70/100
390/390 [==============================] - 25s 64ms/step - loss: 0.0864 - acc: 0.9194 - val_loss: 0.1511 - val_acc: 0.8243
Epoch 71/100
390/390 [==============================] - 26s 65ms/step - loss: 0.0867 - acc: 0.9204 - val_loss: 0.1613 - val_acc: 0.8082
Epoch 72/100
390/390 [==============================] - 25s 63ms/step - loss: 0.0857 - acc: 0.9200 - val_loss: 0.1515 - val_acc: 0.8193
Epoch 73/100
390/390 [==============================] - 26s 66ms/step - loss: 0.0847 - acc: 0.9228 - val_loss: 0.1684 - val_acc: 0.7938
Epoch 74/100
390/390 [==============================] - 26s 66ms/step - loss: 0.0848 - acc: 0.9208 - val_loss: 0.1517 - val_acc: 0.8175
Epoch 75/100
390/390 [==============================] - 26s 66ms/step - loss: 0.0852 - acc: 0.9219 - val_loss: 0.1482 - val_acc: 0.8249
Epoch 76/100
390/390 [==============================] - 25s 64ms/step - loss: 0.0824 - acc: 0.9249 - val_loss: 0.1522 - val_acc: 0.8176
Epoch 77/100
390/390 [==============================] - 26s 66ms/step - loss: 0.0836 - acc: 0.9237 - val_loss: 0.1481 - val_acc: 0.8220
Epoch 78/100
390/390 [==============================] - 26s 66ms/step - loss: 0.0840 - acc: 0.9224 - val_loss: 0.1545 - val_acc: 0.8208
Epoch 79/100
390/390 [==============================] - 26s 66ms/step - loss: 0.0833 - acc: 0.9253 - val_loss: 0.1530 - val_acc: 0.8210
Epoch 80/100
390/390 [==============================] - 25s 65ms/step - loss: 0.0810 - acc: 0.9265 - val_loss: 0.1517 - val_acc: 0.8218
Epoch 81/100
390/390 [==============================] - 25s 65ms/step - loss: 0.0820 - acc: 0.9244 - val_loss: 0.1555 - val_acc: 0.8191
Epoch 82/100
390/390 [==============================] - 25s 65ms/step - loss: 0.0825 - acc: 0.9249 - val_loss: 0.1586 - val_acc: 0.8140
Epoch 83/100
390/390 [==============================] - 26s 65ms/step - loss: 0.0797 - acc: 0.9286 - val_loss: 0.1569 - val_acc: 0.8168
Epoch 84/100
390/390 [==============================] - 26s 65ms/step - loss: 0.0804 - acc: 0.9286 - val_loss: 0.1503 - val_acc: 0.8282
Epoch 85/100
390/390 [==============================] - 26s 66ms/step - loss: 0.0794 - acc: 0.9297 - val_loss: 0.1565 - val_acc: 0.8178
Epoch 86/100
390/390 [==============================] - 26s 66ms/step - loss: 0.0800 - acc: 0.9285 - val_loss: 0.1553 - val_acc: 0.8168
Epoch 87/100
390/390 [==============================] - 25s 65ms/step - loss: 0.0796 - acc: 0.9288 - val_loss: 0.1554 - val_acc: 0.8236
Epoch 88/100
390/390 [==============================] - 26s 66ms/step - loss: 0.0796 - acc: 0.9294 - val_loss: 0.1467 - val_acc: 0.8288
Epoch 89/100
390/390 [==============================] - 25s 65ms/step - loss: 0.0785 - acc: 0.9298 - val_loss: 0.1577 - val_acc: 0.8142
Epoch 90/100
390/390 [==============================] - 25s 64ms/step - loss: 0.0787 - acc: 0.9312 - val_loss: 0.1624 - val_acc: 0.8103
Epoch 91/100
390/390 [==============================] - 26s 66ms/step - loss: 0.0774 - acc: 0.9335 - val_loss: 0.1609 - val_acc: 0.8121
Epoch 92/100
390/390 [==============================] - 26s 66ms/step - loss: 0.0767 - acc: 0.9348 - val_loss: 0.1488 - val_acc: 0.8230
Epoch 93/100
390/390 [==============================] - 25s 64ms/step - loss: 0.0767 - acc: 0.9338 - val_loss: 0.1482 - val_acc: 0.8243
Epoch 94/100
390/390 [==============================] - 26s 66ms/step - loss: 0.0772 - acc: 0.9334 - val_loss: 0.1510 - val_acc: 0.8212
Epoch 95/100
390/390 [==============================] - 25s 64ms/step - loss: 0.0752 - acc: 0.9344 - val_loss: 0.1546 - val_acc: 0.8197
Epoch 96/100
390/390 [==============================] - 25s 65ms/step - loss: 0.0760 - acc: 0.9339 - val_loss: 0.1537 - val_acc: 0.8196
Epoch 97/100
390/390 [==============================] - 26s 67ms/step - loss: 0.0759 - acc: 0.9346 - val_loss: 0.1526 - val_acc: 0.8225
Epoch 98/100
390/390 [==============================] - 25s 64ms/step - loss: 0.0748 - acc: 0.9371 - val_loss: 0.1531 - val_acc: 0.8196
Epoch 99/100
390/390 [==============================] - 25s 65ms/step - loss: 0.0751 - acc: 0.9366 - val_loss: 0.1530 - val_acc: 0.8239
Epoch 100/100
390/390 [==============================] - 26s 66ms/step - loss: 0.0745 - acc: 0.9379 - val_loss: 0.1544 - val_acc: 0.8219
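
A log like the one above is easier to read once the per-epoch metrics are pulled out. The following is a small sketch using only the standard library; parse_log and the two regular expressions are hypothetical helpers, and log_text holds just three lines copied from the log (epochs 1, 62 and 100) rather than the full output.

```python
import re

# Three epochs copied verbatim from the training log.
log_text = """\
Epoch 1/100
390/390 [==============================] - 33s 83ms/step - loss: 0.4269 - acc: 0.3457 - val_loss: 0.3927 - val_acc: 0.4050
Epoch 62/100
390/390 [==============================] - 26s 66ms/step - loss: 0.0910 - acc: 0.9120 - val_loss: 0.1473 - val_acc: 0.8312
Epoch 100/100
390/390 [==============================] - 26s 66ms/step - loss: 0.0745 - acc: 0.9379 - val_loss: 0.1544 - val_acc: 0.8219
"""

EPOCH_RE = re.compile(r"^Epoch (\d+)/\d+$")
METRIC_RE = re.compile(r"(\w+): (\d+\.\d+)")


def parse_log(text):
    """Return one dict per epoch with the logged metrics as floats."""
    rows, epoch = [], None
    for line in text.splitlines():
        match = EPOCH_RE.match(line.strip())
        if match:
            epoch = int(match.group(1))
        elif epoch is not None and "val_acc" in line:
            row = {"epoch": epoch}
            row.update((key, float(value)) for key, value in METRIC_RE.findall(line))
            rows.append(row)
    return rows


rows = parse_log(log_text)
best = max(rows, key=lambda r: r["val_acc"])
last = rows[-1]
gap = last["acc"] - last["val_acc"]
print("best val_acc %.4f at epoch %d" % (best["val_acc"], best["epoch"]))
print("train/val accuracy gap at epoch %d: %.4f" % (last["epoch"], gap))
```

Across the full log the highest validation accuracy is 0.8312, reached at epoch 62, while training accuracy keeps rising to 0.9379 by epoch 100. The widening gap between the two suggests the model is overfitting somewhat in the later epochs.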

Full code:

https://github.com/leonorand/capsule_network/blob/master/CNN_Capsule_CIFAR10.ipynb
