Scheduled Sampling简单理解

paper link

起因

最近写seq2seq跑时序预测,问题层出不穷,还是基础打的不牢固;早上搜索到时候看到Scheduled Sampling还在疑惑是啥?想看看能不能加到代码里面
扒了好多博客,看的还是云里雾里;看了代码后逐渐明白了

  • Teacher Forcing != Teahcer Forcing Ratio(好多博客里面将这两个混为一谈,直接把我看迷茫了)
  • Teacher Forcing Ratio=Scheduled Sampling

Teacher Forcing

Teacher Forcing 可以理解为:学生请教老师一套卷子上的所有题目,老师想交他的这张卷子上所有题目对应类别的解法,但是学生只关注到目前的这一卷子,在这一张卷子上过程、结果越来越正确;可到了考试的时候,试卷换了,学生只会那一张卷子,考试的时候依旧不及格。
PS:没找到合适的图,找到了再补图吧
可以将请教的部分理解为训练部分,考试的部分理解为验证/测试部分;当然神经网络的学习并不会这么极端,网络的学习会使得结果会像那么一回事;同时也会引发其他问题:Exposure Bias、 Overcorrect等问题,可以看这里知乎专栏

Teacher Forcing Ratio/Scheduled Sampling

我理解的Teacher Forcing Ratio的加入就是scheduled sampling,通过在每一个时间步的输出后更具概率决定下一次的输入:Ground Truth或者Model Output;图中的sampled可以理解为模型的输出
Scheduled Sampling简单理解_第1张图片

class Seq2Seq(nn.Module):
	def __init__(self):
	.....
	def Forward(self,x,y,teacher_raio):
		.....
		output,hidden=self.decoder()
		next_input=output if random.random()<teacher_ratio elif y
		#如果随机数小于teacher ratio使用模型输出值,否则使用真实值
		#请不要固定随机数的种子点,否则就会一直使用真实值或者模型输出值
		#Teacher Forcing Ratio default:0.5
		.....

衰减策略

别人的code

  • Linear: Ratio is decreased by forcing_decay every batch.
  • Exponential: Ratio is multiplied by forcing_decay every batch.
  • Inverse sigmoid: Ratio is k/(k + exp(i/k)) where k is forcing_decay and i is batch number.
  • 当然你也可以设置
    Scheduled Sampling简单理解_第2张图片
    如果写的有问题,欢迎指出!!!

Ref

Scheduled Sampling:RNN的训练trick
[论文解读]Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks
【Scheduled sampling】— 解决训练和预测产生的矛盾
一文弄懂关于循环神经网络(RNN)的Teacher Forcing训练机制

你可能感兴趣的:(Pytorch,pytorch,python)