TensorFlow 多输出模型

数据集

from tensorflow import keras
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import IPython.display as display
import random
import pathlib

data_dir = '../input/multi-output-model-dataset/multi-output-classification/dataset'
data_root = pathlib.Path(data_dir)


def split_label(label_name):
    """Split a ``<color>_<item>`` class-directory name into ``(color, item)``.

    Only the first underscore separates the two parts, so item names that
    themselves contain underscores (e.g. ``red_t_shirt``) stay intact.

    Raises:
        ValueError: if *label_name* contains no underscore.
    """
    color, sep, item = label_name.partition('_')
    if not sep:
        raise ValueError(f"label {label_name!r} is not of the form '<color>_<item>'")
    return color, item


def build_index(names):
    """Map each distinct name to an integer index, assigned in sorted order."""
    # Sorting makes the mapping reproducible; iterating a set of strings gives
    # an order that changes between interpreter runs (hash randomization).
    return {name: index for index, name in enumerate(sorted(set(names)))}


# Every regular file one level below a class directory is treated as an image.
all_image_paths = [str(path) for path in data_root.glob('*/*') if path.is_file()]
image_count = len(all_image_paths)
# Shuffle once so the take/skip train/test split later draws from all classes.
random.shuffle(all_image_paths)

# Class directories are named "<color>_<item>", e.g. "blue_jeans".
label_names = sorted(item.name for item in data_root.glob('*/') if item.is_dir())
print(label_names)

# Distinct colors, sorted so their indices are stable across runs.
color_label_names = sorted({split_label(name)[0] for name in label_names})
print(color_label_names)
# Distinct item types, sorted for the same reason.
item_label_names = sorted({split_label(name)[1] for name in label_names})
print(item_label_names)

# Integer encoding of each color.
color_label_to_index = build_index(color_label_names)
print(color_label_to_index)
# Integer encoding of each item type.
item_label_to_index = build_index(item_label_names)
print(item_label_to_index)

# An image's label is the name of the directory that contains it.
all_image_labels = [pathlib.Path(path).parent.name for path in all_image_paths]
color_labels = [color_label_to_index[split_label(label)[0]] for label in all_image_labels]
item_labels = [item_label_to_index[split_label(label)[1]] for label in all_image_labels]

def load_and_preprocess_image(path, image_size=(224, 224)):
    """Read an image file and return it as a float32 tensor scaled to [-1, 1].

    Args:
        path: File path, as a Python ``str`` or a scalar string tensor (the
            function is also used inside ``tf.data.Dataset.map``).
        image_size: ``(height, width)`` to resize to. Defaults to the
            224x224 input size used by the model in this script.

    Returns:
        A float32 tensor of shape ``(height, width, 3)`` with values in
        ``[-1, 1]``.
    """
    image = tf.io.read_file(path)
    # decode_image accepts JPEG, PNG, BMP and GIF (first frame only with
    # expand_animations=False), and always yields a 3-D (H, W, C) tensor.
    image = tf.io.decode_image(image, channels=3, expand_animations=False)
    image = tf.image.resize(image, image_size)
    image = tf.cast(image, tf.float32)
    # Map [0, 255] to [-1, 1]: the same scaling as
    # tf.keras.applications.mobilenet_v2.preprocess_input.
    return image / 127.5 - 1.0

AUTOTUNE = tf.data.experimental.AUTOTUNE
BATCH_SIZE = 16


def load_and_preprocess_from_path_label(path, labels):
    """Decode the image at *path* and pass its ``(color, item)`` labels through."""
    return load_and_preprocess_image(path), labels


# (path, (color_index, item_index)) pairs. all_image_paths was shuffled once
# when it was built, so take/skip below yields a random but fixed split.
path_label_ds = tf.data.Dataset.from_tensor_slices(
    (all_image_paths, (color_labels, item_labels)))

test_count = int(image_count * 0.2)  # hold out 20% for validation
train_count = image_count - test_count

# Shuffle file paths *before* decoding. Shuffling decoded images would need a
# buffer of train_count 224x224x3 float32 tensors (~600 KB each), i.e. the
# whole training set in RAM. max(..., 1) avoids buffer_size=0 on tiny data.
train_data = (path_label_ds.skip(test_count)
              .shuffle(buffer_size=max(train_count, 1))
              .repeat()
              .map(load_and_preprocess_from_path_label, num_parallel_calls=AUTOTUNE)
              .batch(BATCH_SIZE)
              .prefetch(buffer_size=AUTOTUNE))

# Validation split: finite (not repeated), so each pass covers it exactly once.
test_data = (path_label_ds.take(test_count)
             .map(load_and_preprocess_from_path_label, num_parallel_calls=AUTOTUNE)
             .batch(BATCH_SIZE)
             .prefetch(buffer_size=AUTOTUNE))


# ImageNet-pretrained MobileNetV2 backbone without its classification head.
mobile_net = tf.keras.applications.MobileNetV2(input_shape=(224, 224, 3),
                                               include_top=False,
                                               weights='imagenet')

# Freeze the backbone so only the two heads train. The attribute must be
# spelled `trainable`: a misspelling silently creates a new, unused attribute
# and leaves every backbone weight trainable.
mobile_net.trainable = False

inputs = tf.keras.Input(shape=(224, 224, 3))
# training=False keeps the backbone's BatchNormalization layers in inference
# mode, which matters if the backbone is later unfrozen for fine-tuning.
x = mobile_net(inputs, training=False)
x = tf.keras.layers.GlobalAveragePooling2D()(x)

# Color head: its own hidden layer on top of the shared pooled features.
x1 = tf.keras.layers.Dense(1024, activation='relu')(x)
out_color = tf.keras.layers.Dense(len(color_label_names),
                                  activation='softmax',
                                  name='out_color')(x1)

# Item-type head.
x2 = tf.keras.layers.Dense(1024, activation='relu')(x)
out_item = tf.keras.layers.Dense(len(item_label_names),
                                 activation='softmax',
                                 name='out_item')(x2)

model = tf.keras.Model(inputs=inputs,
                       outputs=[out_color, out_item])
# summary() prints the table itself and returns None, so don't print() it.
model.summary()

# One loss and one accuracy metric per named output. A dict keyed by output
# name works across Keras versions; newer Keras rejects a flat metrics list
# whose length differs from the number of model outputs.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
              loss={'out_color': 'sparse_categorical_crossentropy',
                    'out_item': 'sparse_categorical_crossentropy'},
              metrics={'out_color': 'acc',
                       'out_item': 'acc'})

# train_data repeats forever, so an epoch must be bounded explicitly; use at
# least one step so a dataset smaller than one batch still trains.
train_steps = max(1, train_count // BATCH_SIZE)
model.fit(train_data,
          epochs=15,
          steps_per_epoch=train_steps,
          # test_data is finite, so without validation_steps Keras evaluates
          # all of it, including the final partial batch.
          validation_data=test_data)

# Sanity check on one image from the dataset (batch dimension added).
test_img = tf.expand_dims(
    load_and_preprocess_image(str(data_root / 'blue_jeans' / '00000012.jpg')), 0)
pre = model.predict(test_img)  # [color probabilities, item probabilities]
pre_color = int(np.argmax(pre[0][0]))
pre_item = int(np.argmax(pre[1][0]))

# Invert the label encodings to report human-readable class names.
index_to_color = {index: name for name, index in color_label_to_index.items()}
index_to_item = {index: name for name, index in item_label_to_index.items()}
print(pre_color, pre_item, index_to_color[pre_color], index_to_item[pre_item])

# Undo the [-1, 1] scaling to display the image in [0, 1].
plt.imshow((test_img[0].numpy() + 1) / 2)
plt.show()

你可能感兴趣的：(深度学习, python, tensorflow, zigbee, cobol)