import tensorflow as tf
import numpy as np
from tensorflow.keras import *
import os, sys, glob, shutil, json
from tensorflow import keras
# Ask TensorFlow to grow GPU memory on demand rather than reserving it all.
# Iterating over the list (instead of indexing [0]) keeps the script usable
# when no GPU is visible, where the index would raise IndexError, and applies
# the same setting to every GPU on multi-GPU hosts.
physical_device = tf.config.experimental.list_physical_devices("GPU")
for _gpu in physical_device:
    tf.config.experimental.set_memory_growth(_gpu, True)
BATCH_SIZE = 32


def _load_annotations(json_path):
    """Parse the JSON file at *json_path*, closing the file handle afterwards."""
    with open(json_path) as handle:
        return json.load(handle)


def _pad_label(label, length=5, pad=10):
    """Return *label* right-padded with *pad* and cut to exactly *length* items.

    Labels longer than *length* are truncated; shorter ones are filled with
    the *pad* value.
    """
    return (label + (length - len(label)) * [pad])[:length]


# Training annotations: raw per-image label lists and their fixed-length
# (5 entries, padded with 10) counterparts.
train_json = _load_annotations('/home/tralia/Downloads/SVHN_CLF/train.json')
train_label = [train_json[x]['label'] for x in train_json]
train_label_5 = [_pad_label(label) for label in train_label]

# Validation annotations, processed the same way.
val_json = _load_annotations('/home/tralia/Downloads/SVHN_CLF/val.json')
val_label = [val_json[x]['label'] for x in val_json]
val_label_5 = [_pad_label(label) for label in val_label]
def load_train_image(img_path):
    """Load one training image from *img_path* with random augmentation.

    The file is decoded as a 3-channel PNG, resized to 64x128, divided by
    255, randomly cropped to 60x120x3 and finally given a random brightness
    shift (max_delta=0.5).

    Returns:
        The augmented image tensor.
    """
    raw_bytes = tf.io.read_file(img_path)
    decoded = tf.image.decode_png(raw_bytes, channels=3)
    scaled = tf.image.resize(decoded, (64, 128)) / 255.0
    as_float = tf.image.convert_image_dtype(scaled, tf.float32)
    # Crop slightly smaller than the resized image so the window position varies.
    cropped = tf.image.random_crop(as_float, size=[60, 120, 3])
    return tf.image.random_brightness(cropped, max_delta=0.5)
def load_val_image(img_path):
    """Load one validation image from *img_path* without augmentation.

    The file is decoded as a 3-channel PNG, resized directly to 60x120 and
    divided by 255.

    Returns:
        The preprocessed image tensor.
    """
    raw_bytes = tf.io.read_file(img_path)
    decoded = tf.image.decode_png(raw_bytes, channels=3)
    scaled = tf.image.resize(decoded, (60, 120)) / 255.0
    return tf.image.convert_image_dtype(scaled, tf.float32)
# Image datasets. list_files() shuffles file names by default, which would
# pair every image with an arbitrary label once zipped with the label
# datasets below; shuffle=False keeps the file order deterministic.
# NOTE(review): this assumes the JSON annotation entries are listed in the
# same order in which list_files() returns the files - confirm against the data.
train_batches_img = tf.data.Dataset.list_files(
    "/home/tralia/Downloads/SVHN_CLF/train/*.png", shuffle=False
).map(load_train_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
val_batches_img = tf.data.Dataset.list_files(
    "/home/tralia/Downloads/SVHN_CLF/val/*.png", shuffle=False
).map(load_val_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)

# Label datasets built from the fixed-length label lists.
train_batches_lbl = tf.data.Dataset.from_tensor_slices(np.array(train_label_5))
val_batches_lbl = tf.data.Dataset.from_tensor_slices(np.array(val_label_5))

# Training pairs are shuffled only after zipping, so each image keeps its label.
train_batches = tf.data.Dataset.zip((train_batches_img, train_batches_lbl))
train_batches = (
    train_batches.shuffle(buffer_size=10000)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

# Validation stays in file order: predictions are later compared with
# val_label_5 by position. (The former shuffle(buffer_size=1) was a no-op.)
val_batches = tf.data.Dataset.zip((val_batches_img, val_batches_lbl))
val_batches = val_batches.batch(BATCH_SIZE).prefetch(tf.data.experimental.AUTOTUNE)
class SVHNModel(keras.Model):
    """ResNet50 backbone followed by five parallel 11-way softmax heads.

    One head is created per label position; each head outputs a probability
    distribution over 11 classes.
    """

    def __init__(self):
        super().__init__()

    def build(self, input_shape):
        """Create the backbone, the pooling layer and the five heads."""
        backbone = keras.applications.ResNet50(include_top=False, weights='imagenet')
        # The pretrained backbone is fine-tuned together with the heads.
        backbone.trainable = True
        self.cnn = backbone
        self.average_pooling_layer = layers.GlobalAveragePooling2D()
        # Heads are stored as fc1 .. fc5, created in that order.
        for head_index in range(1, 6):
            setattr(self, 'fc{}'.format(head_index), layers.Dense(11, activation='softmax'))
        super().build(input_shape)

    def call(self, inputs):
        """Return a 5-tuple with one softmax output per label position."""
        features = self.average_pooling_layer(self.cnn(inputs))
        return (
            self.fc1(features),
            self.fc2(features),
            self.fc3(features),
            self.fc4(features),
            self.fc5(features),
        )
tf.keras.backend.clear_session()
model = SVHNModel()
model.build(input_shape=(None, 60, 120, 3))
model.summary()


def _multi_digit_loss(labels, predictions):
    """Return the sum over label positions of the batch-mean cross-entropy.

    Args:
        labels: integer tensor indexed as ``labels[:, position]``.
        predictions: sequence holding one model output per label position.
    """
    total = 0.0
    for position, probs in enumerate(predictions):
        onehot = tf.one_hot(labels[:, position], depth=11)
        # The model heads end in a softmax activation, so they emit
        # probabilities; from_logits must therefore be False.
        per_sample = tf.losses.categorical_crossentropy(onehot, probs, from_logits=False)
        total = total + tf.reduce_mean(per_sample)
    return total


best_loss = 1000.0
# `learning_rate` is the documented argument name; `lr` is a deprecated alias.
optimizer = optimizers.Adam(learning_rate=1e-2)
variables = model.trainable_variables

# Ground-truth strings for the accuracy check. The pad class (10) is removed
# here exactly as it is removed from the predictions below; otherwise any
# label shorter than five digits could never match its prediction.
val_label = [''.join(str(digit) for digit in label if digit != 10) for label in val_label_5]

for epoch in range(3):
    # ---- training pass ----
    # NOTE(review): the model is called without an explicit `training`
    # argument - confirm this is the intended mode for the backbone.
    train_loss = []
    for x, y in train_batches:
        with tf.GradientTape() as tape:
            loss = _multi_digit_loss(y, model(x))
        train_loss.append(loss)
        grads = tape.gradient(loss, variables)
        optimizer.apply_gradients(zip(grads, variables))
    train_loss = tf.reduce_mean(train_loss)

    # ---- validation pass ----
    val_loss = []
    val_pred = []
    for x, y in val_batches:
        outputs = model(x)
        val_loss.append(_multi_digit_loss(y, outputs))
        # Most likely class per head, arranged as one row per sample.
        val_pred.append(np.vstack([np.array(head).argmax(1) for head in outputs]).T)
    val_predict_label = np.vstack(val_pred)

    # Turn each predicted row into a digit string, dropping the pad class.
    val_label_pred = [''.join(map(str, row[row != 10])) for row in val_predict_label]
    val_char_acc = np.mean(np.array(val_label_pred) == np.array(val_label))
    val_loss = tf.reduce_mean(val_loss)

    print('Epoch: {0}, Train loss: {1} \t Val loss: {2}'.format(epoch, train_loss, val_loss))
    print(val_char_acc)

    # Keep only the weights with the lowest validation loss seen so far.
    if val_loss < best_loss:
        best_loss = val_loss
        model.save_weights('./save_weights/my_save_weights')