--整点歪门邪道--
这两天做一个多目标点识别的任务,在尝试了多任务训练后,发现~多任务的性能要远好于单任务训练。但是发现了一个问题:
然而!训练过程中,y_pred实际上是每一个输出对应一个评价指标,例如:
并没有一个指标可以告诉我!这堆点当中和哪个偏差是最大的!(做过多任务的应该都见过这种输出,偶尔会由这种困扰)
当然,也有解决办法~你瞅一眼不就行了~也没有多少结果~
实际上,这个问题其实影响还是挺多的,如(后面大家如果觉得还有一些场景或者问题需要用到全局指标进行评价,也可以写在评论区):
该任务后续有retrain的需要,而且懒,不想一直盯着训练过程,因此选择使用checkpoint保存模型。但是本次任务的需求其实是要求所有点的误差尽量小,而不是平均误差最小,即实际指标其实是令这几个目标点中误差最大的点最小,因此必须要获取这几个点的欧式距离误差,并且给出其中的最大值~
OK~现在,歪门邪道开始!
众所都周知,tensorflow有个keras~keras里有个callback.Callback~
翻译翻译~就是你训练过程中当前epoch或者batch的全部数据我这里都有,你来拿吧,我还是个字典!
打印一下,看看里面到底都有啥~
class Show_logs(keras.callbacks.Callback):
def __init__(self):
super().__init__()
def on_epoch_end(self, epoch, logs=None):
print(logs)
在model.fit的callback里面配置下:
model.fit(x=self.train_box['x'],y=tuple(self.train_box['y']),
batch_size=self.train_setting['batchsize'], epochs=self.train_setting['epoch'],validation_data(self.test_box['x'],tuple(self.test_box['y'])),validation_freq=1, verbose=2, shuffle=True,callbacks=[Show_logs])
运行,得到输出如下(数据有点异常,不重要!):
中所都周知x2!有了数据取出来算一下不就可以得到这几个点的最大误差是啥了嘛!
class Multy_Maximum_deviation(keras.callbacks.Callback):
def __init__(self,target_list):
super(Multy_Maximum_deviation,self).__init__()
self.target_list=target_list
def on_epoch_end(self, epoch, logs=None):
div=max([logs['val_{0}_偏差'.format(_)] for _ in self.target_list])
print('最大偏差:%.3f'%div)
拿到数据,运行下,发现结果符合要求~至此!你就实现了从logs里面取出每一个epoch的结果并且计算得到自己想要的整体性能指标~
但是!这个指标只能瞅一瞅,那我需要用checkpoint保存模型怎么办??
众所都周知x3,作为一个废物,当无计可施并且网上找不到的时候,就只能查源码了。
打开keras.callbacks.py~在ModelCheckpoint中发现了这个:
跳转过去看看~
soga~原来checkpoint里面的评价指标值也是来源于logs,那这就好办了~
众所都周知x4,配置到callbacks里面的方法必然有执行顺序,从左到右嘛~那我们是不是可以在checkpoint被调用之前,将我们计算得到的"最大偏差"加入logs中,这样不就可以保存了么~
修改一下之前的程序~加一行代码:
class Multy_Maximum_deviation(keras.callbacks.Callback):
def __init__(self,target_list):
super(Multy_Maximum_deviation,self).__init__()
self.target_list=target_list
def on_epoch_end(self, epoch, logs=None):
div=max([logs['val_{0}_偏差'.format(_)] for _ in self.target_list])
logs['最大偏差']=div
print('最大偏差:%.3f'%div)
同时修改下checkpoint:
Check_points=keras.callbacks.ModelCheckpoint(filepath=model_savepath,monitor='最大偏差',verbose=1,save_best_only=True,mode='min')
运行~看一下效果:
完美~偷懒成功~!
感谢观看~