import matplotlib.pyplot as plt
# Overlay test accuracy (left y-axis) and training loss (right y-axis) for two
# runs on one shared x-axis. The int_epoch*/int_prec*/int_loss* series are
# expected to be defined earlier in the file.
fig = plt.figure()
ax1 = fig.add_subplot(111)

# Left y-axis: test-accuracy curves.
l1 = ax1.plot(int_epoch2, int_prec12, 'darkcyan', label='test accuracy of ResNet20')
l2 = ax1.plot(int_epoch3, int_prec13, 'darkorange', label='test accuracy of ...')
ax1.set_xlabel('Epoch', fontsize=15)
ax1.set_ylabel('Test accuracy', fontsize=15)
ax1.set_title("...")

# twinx() creates a second Axes that shares ax1's x-axis but has its own
# y-axis on the right, so accuracy and loss can be drawn on different scales.
ax2 = ax1.twinx()
l3 = ax2.plot(int_epoch1, int_loss1, 'gold', label='training loss of ResNet20')
l4 = ax2.plot(int_epoch3, int_loss3, 'darkmagenta', label='training loss of ...')
# The x-axis is shared with ax1, so this limit applies to every curve.
ax2.set_xlim(left=0, right=400)
ax2.set_ylim(0, 2.0)
ax2.set_ylabel('Loss', fontsize=15)

# plot() returns a list of Line2D objects; concatenating them lets one legend
# describe the lines of both Axes (legend() with no handles would only see the
# lines of the Axes it is called on).
lns = l1 + l2 + l3 + l4
labs = [line.get_label() for line in lns]
# Attach the legend to ax2: the twin Axes is drawn above ax1, so a legend on
# ax1 can end up hidden behind the loss curves.
ax2.legend(lns, labs, loc='center right')
plt.show()