# Input placeholder: a single image of 224x224 pixels with 3 channels
# (batch dimension fixed to 1).
x = tf.placeholder(dtype=tf.float32, shape=[1, 224, 224, 3])
sess = tf.InteractiveSession()

# Build the AlexNet v2 graph with 2 output classes in inference mode
# (is_training=False), under the arg_scope provided by the alexnet module.
arg_scope = alexnet.alexnet_v2_arg_scope()
with slim.arg_scope(arg_scope):
    logits, _ = alexnet.alexnet_v2(x, num_classes=2, is_training=False)

saver = tf.train.Saver()
tf.global_variables_initializer().run()
# Restore the trained weights from the checkpoint; restored variables
# overwrite the values set by the initializer above.
saver.restore(sess, "logs/alexnet_2000.ckpt")
def get_test_result(sess, path, true_result=0, begin=0, end=1000):
    """Classify every image under ``path`` and collect per-image scores.

    Images are fed one at a time through the module-level graph
    (placeholder ``x`` -> tensor ``logits``) using ``sess``. In each
    directory visited by ``os.walk`` (the walk is recursive), files are
    sorted by the integer value of their stem (e.g. ``"12.jpg"`` -> 12) and
    only the slice ``[begin:end]`` is evaluated. Pass ``end=None`` to go up
    to and including the last file; ``end=-1`` drops the last file, as with
    any Python slice.

    Args:
        sess: TensorFlow session in which ``logits`` is evaluated.
        path: Root directory of the images.
        true_result: Index of the correct class for every image under ``path``.
        begin: Start index of the slice of sorted filenames per directory.
        end: Stop index (exclusive) of that slice.

    Returns:
        ``(lit, err)``. ``lit`` is the list of class-0 logits, sorted
        ascending. ``err`` is the fraction of evaluated images whose argmax
        differs from ``true_result``.

    Raises:
        FileNotFoundError: If ``path`` is not a directory.
        ValueError: If an image cannot be read, a filename stem is not an
            integer, or no image was evaluated at all.
    """
    if not os.path.isdir(path):
        # os.walk yields nothing for a missing directory; fail loudly instead
        # of reporting results for zero images.
        raise FileNotFoundError(path)
    error = 0
    total = 0
    lit = []
    for parent, _, filenames in os.walk(path):
        # Filenames are expected to be "<integer>.<ext>"; sort numerically so
        # that begin/end select a reproducible range.
        filenames.sort(key=lambda name: int(os.path.splitext(name)[0]))
        filenames = filenames[begin:end]
        # Count across all directories so the error rate is not computed
        # from the last directory's file count only.
        total += len(filenames)
        for filename in filenames:
            file_path = os.path.join(parent, filename)
            raw = cv2.imread(file_path)
            if raw is None:
                # cv2.imread reports unreadable/non-image files by returning None.
                raise ValueError("could not read image: " + file_path)
            # Preprocess as for Inception v1: resize to 224x224 and scale
            # pixel values from [0, 255] to [-1, 1].
            image = cv2.resize(raw, (224, 224))
            image = 2 * (image / 255.0) - 1.0
            img = np.reshape(image, [-1, 224, 224, 3])
            # Forward pass; flatten the logits to 1-D (presumably (1, 2) -> (2,)).
            predict = np.reshape(sess.run(logits, feed_dict={x: img}), [-1])
            result = np.argmax(predict, axis=0)
            if result != true_result:
                print(predict, file_path)
                error += 1
            else:
                print(predict)
            # NOTE(review): the score kept is the raw class-0 logit, not
            # softmax(...)[0] or logit0 - logit1, so ranking by it may not
            # agree with the argmax decision above; confirm this is intended.
            lit.append(predict[0])
    if total == 0:
        raise ValueError("no images evaluated under " + path)
    # Ascending order; use lit.sort(reverse=True) for descending scores.
    lit.sort()
    err = error / total
    return lit, err
# Collect the per-image scores for both classes: normal images should be
# predicted as class 1, polyp images as class 0. Normal images use the default
# slice [0:1000]; polyp images are taken from index 5000 to the last file.
normal_list, err_normal = get_test_result(sess, "/home/hdl/ALL_IMAGE/dataset/normal", true_result=1)
# end=None (not -1) so that the final polyp image is included in the slice.
polyp_list, err_polyp = get_test_result(sess, "/home/hdl/ALL_IMAGE/dataset/polyp", true_result=0, begin=5000, end=None)
# ROC curve with polyp as the positive class: for every observed score t
# (highest first, duplicates included), an image is predicted "polyp" when
# its score >= t.
normal = np.sort(np.asarray(normal_list))
polyp = np.sort(np.asarray(polyp_list))
thresholds = np.sort(np.concatenate((normal, polyp)))[::-1]
# In an ascending array a, the number of entries >= t equals
# len(a) - searchsorted(a, t, side="left"). This gives the same counts as
# comparing every score against every threshold, without the quadratic cost.
tp = len(polyp) - np.searchsorted(polyp, thresholds, side="left")
fp = len(normal) - np.searchsorted(normal, thresholds, side="left")
# At every threshold tp + fn == len(polyp) and fp + tn == len(normal).
TPR = tp / len(polyp)
FPR = fp / len(normal)
# Save as one flat array: all FPR values followed by all TPR values
# (reshape to (2, -1) when loading to split them again).
arr = np.concatenate((FPR, TPR), axis=0)
np.save("Alexnet_scratch.npy", arr)
三、绘制ROC曲线
import numpy as np
from matplotlib import pyplot as plot
# (file, line colour, legend label) for each model's saved ROC points.
# Each file is assumed to hold all FPR values followed by all TPR values as a
# flat array, as written by the ROC script above; reshape to (2, N) splits them.
curves = [
    ("Alexnet_finetune.npy", "red", "Alexnet finetune"),
    ("Alexnet_scratch.npy", "blue", "Alexnet scratch"),
    ("VGG.npy", "green", "VGG with global avgpool"),
]
plot.title("Receiver operating characteristic curve")
plot.xlabel("False Positive Rate")
plot.ylabel("True Positive Rate")
for filename, color, label in curves:
    # Positional shape: the `newshape` keyword is deprecated in recent NumPy.
    fpr, tpr = np.reshape(np.load(filename), (2, -1))
    plot.plot(fpr, tpr, color=color, label=label)
plot.legend()
plot.show()