PyTorch notes: 10) Source code walkthrough of On the Automatic Generation of Medical Imaging Reports

Paper: https://arxiv.org/pdf/1711.08195.pdf
GitHub: https://github.com/ZexinYan/Medical-Report-Generation

Contents

          • Model architecture
          • Model overview
          • Data
          • Model components
          • Training

Model architecture

[Figure 1: model architecture diagram]

Model overview

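Given a medical image (the example data are chest X-rays), a CNN performs multi-label classification on it. The CNN's pooled last-layer output (one 1-D vector per image) serves as the Visual Features (VF), and the embeddings of the k most probable predicted tags (topk) serve as the Semantic Features (SF). A co-attention module combines VF and SF into a context vector ctx. The SentenceLSTM maps ctx to a stop vector and a topic vector: the former signals when to stop generating sentences, and the latter drives the WordLSTM, which outputs the words of one sentence of the report.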

Data

The data pipeline lives mainly in dataset.py. The author's example uses 4 images; printing one batch gives:

for i, (image, image_id, label, target, prob) in enumerate(data_loader):
	print("image.shape",image.shape)
	print("image_id",image_id)
	print("label.shape",label.shape)
	print("target",target.shape)
	print("prob",prob)
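image.shape torch.Size([4, 3, 224, 224]) # 4 RGB images, 224x224
image_id ('CXR1972_IM-0633-1001.png', 'CXR932_IM-2430-3001.png', 'CXR1149_IM-0101-1001.png', 'CXR3976_IM-2035-1001.png') # image file names
label.shape torch.Size([4, 210]) # tag labels of the 4 images (210 tag classes)
target (4, 6, 18) # target: each report is split into sentences on '.', each sentence into word indices; 6 = sentence slots per report in this batch, 18 = the longest sentence (in words) in this batch
prob # 1 = a real sentence; the first 0 marks the sentence at which generation should stop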
[[1. 1. 1. 1. 1. 0.]
 [1. 1. 1. 1. 1. 0.]
 [1. 1. 1. 1. 1. 0.]
 [1. 1. 1. 1. 1. 0.]]
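To make the target/prob layout concrete, here is a minimal padding sketch in NumPy. `pad_reports` is a hypothetical helper, not the repository's collate function. The extra trailing sentence slot is an assumption, inferred from the printed prob above, where every row ends with 0.

```python
import numpy as np


def pad_reports(reports):
    """Pad tokenized reports into a (batch, sentence_slots, max_words) array plus stop labels.

    reports: list of reports; a report is a list of sentences, a sentence is a
             list of word indices (0 is reserved for padding).
    Hypothetical helper: one extra sentence slot is added so every report has a
    0 in prob, i.e. a position where the model should learn to stop.
    """
    n_slots = max(len(r) for r in reports) + 1      # +1 trailing stop slot (assumption)
    max_words = max(len(s) for r in reports for s in r)
    targets = np.zeros((len(reports), n_slots, max_words), dtype=np.int64)
    prob = np.zeros((len(reports), n_slots), dtype=np.float32)
    for i, report in enumerate(reports):
        for j, sentence in enumerate(report):
            targets[i, j, :len(sentence)] = sentence  # word indices, zero-padded
            prob[i, j] = 1.0                          # a real sentence: keep generating
    return targets, prob


# Two toy reports with 2 and 1 sentences; the word indices are made up.
reports = [[[4, 7, 2], [5, 9, 8, 2]], [[6, 2]]]
targets, prob = pad_reports(reports)
print(targets.shape)  # (2, 3, 4)
print(prob)
```

In this layout, each column j of prob is the stop label for sentence step j, which is how the training loop indexes it (prob_real[:, sentence_index]).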
Model components

1. CNN model
Only the core forward function is shown. The default CNN backbone is resnet152.

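class VisualFeatureExtractor(nn.Module):
    def forward(self, images):
        # feature map before pooling; the "Visual Features" in the figure are actually avg_features
        visual_features = self.model(images)
        # self.avg_func = torch.nn.AvgPool2d(kernel_size=7, stride=1, padding=0)
        # pools each 7x7 channel map to 1x1; squeeze() then yields a (batch, channels) vector
        # (note: squeeze() also drops the batch dim when the batch size is 1)
        avg_features = self.avg_func(visual_features).squeeze()
        return visual_features, avg_features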
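The pooling step can be checked with a NumPy stand-in. The shapes are an assumption for a resnet152 backbone with its own pooling and fc layers removed, fed 224x224 inputs. Averaging a 7x7 map with AvgPool2d(kernel_size=7) is the same as taking the mean over the two spatial axes. The sketch also shows why a bare squeeze() is fragile: with a batch of one image it removes the batch axis too.

```python
import numpy as np

# NumPy stand-in for the backbone output: (batch, channels, height, width).
# Assumed: resnet152 without its pooling/fc layers gives 2048 x 7 x 7 for a 224x224 input.
rng = np.random.default_rng(0)
feature_map = rng.random((1, 2048, 7, 7))

# AvgPool2d(kernel_size=7, stride=1, padding=0) on a 7x7 map averages each channel -> (B, C, 1, 1)
pooled = feature_map.mean(axis=(2, 3), keepdims=True)

squeezed = pooled.squeeze()                       # drops every size-1 axis, including batch
flattened = pooled.reshape(pooled.shape[0], -1)  # keeps the batch axis: (B, C)

print(pooled.shape)     # (1, 2048, 1, 1)
print(squeezed.shape)   # (2048,)
print(flattened.shape)  # (1, 2048)
```

In the PyTorch code, `torch.flatten(x, 1)` or `x.view(x.size(0), -1)` would keep the batch dimension regardless of batch size.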

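2. MLC model
Multi-label classification: one image can carry several tags, and the k most probable tags serve as the semantic features (the tag indices are used directly as indices into the embedding layer).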

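class MLC(nn.Module):
    def forward(self, avg_features):
        # multi-label classification: one score per tag
        tags = self.softmax(self.classifier(avg_features))
        # embed the indices of the k highest-scoring tags as the semantic features
        # (torch.topk(...)[1] is the index tensor)
        semantic_features = self.embed(torch.topk(tags, self.k)[1])
        return tags, semantic_features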
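A NumPy sketch of the MLC output step. It assumes k = 10 tags and 512-dimensional tag embeddings; both are made-up values, and the real ones come from the repository's arguments. Here, argsort on the negated scores plays the role of torch.topk(...)[1].

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(0)
batch, num_tags, k, embed_size = 2, 210, 10, 512  # k and embed_size are made-up values

logits = rng.normal(size=(batch, num_tags))   # stand-in for self.classifier(avg_features)
tags = softmax(logits, axis=1)                # scores over the 210 tags, each row sums to 1

# Equivalent of torch.topk(tags, k)[1]: indices of the k highest scores, highest first.
topk_idx = np.argsort(-tags, axis=1)[:, :k]   # (batch, k)

embedding_table = rng.normal(size=(num_tags, embed_size))  # stand-in for self.embed
semantic_features = embedding_table[topk_idx]              # (batch, k, embed_size)
```

Only the indices are embedded. As a result, no gradient flows from the report losses back into the classifier through semantic_features, so in the training loop shown below the classifier is trained by the tag loss alone.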

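3. CoAttention model
The code follows the formulas in the paper: each "multiply a variable by a weight matrix" step is implemented as an nn.Linear layer. The forward method offers alternative implementations v1 to v5; here we look at the simplest one, v2.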

class CoAttention(nn.Module):
    def v2(self, avg_features, semantic_features, h_sent) -> object:
        """
        h_sent: hidden state of the sentence LSTM
        no bn
        :rtype: object
        """
        W_v = self.W_v(avg_features)
        W_v_h = self.W_v_h(h_sent.squeeze(1))
        # attention weights for the visual features
        alpha_v = self.softmax(self.W_v_att(self.tanh(W_v + W_v_h)))
        # no .sum() here, likely because avg_features is already one pooled vector per image:
        # alpha_v re-weights its channels instead of weighting image regions
        v_att = torch.mul(alpha_v, avg_features)
        # attention weights for the semantic (tag) features
        W_a_h = self.W_a_h(h_sent)
        W_a = self.W_a(semantic_features)
        alpha_a = self.softmax(self.W_a_att(self.tanh(torch.add(W_a_h, W_a))))
        a_att = torch.mul(alpha_a, semantic_features).sum(1)  # weighted sum over the k tags
        # joint context vector from the two attended features
        ctx = self.W_fc(torch.cat([v_att, a_att], dim=1))
        return ctx, alpha_v, alpha_a
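The v2 computation can be traced with a NumPy sketch that replaces each nn.Linear with a random matrix. The helper `linear` and all sizes are made up. The excerpt does not show the layer sizes, so two choices here are my assumptions: W_att maps each tag embedding to a single score, and that score is normalized across the k tags. The sketch makes the asymmetry visible. The visual branch can only re-weight the C channels of one pooled vector, while the semantic branch really does compute a weighted sum over k items.

```python
import numpy as np


def softmax(x, axis):
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def linear(rng, n_in, n_out):
    """Hypothetical stand-in for nn.Linear(n_in, n_out): random weights, zero bias."""
    W = rng.normal(scale=0.1, size=(n_in, n_out))
    b = np.zeros(n_out)

    def apply(x):
        return x @ W + b  # works on (..., n_in) inputs, like nn.Linear

    return apply


rng = np.random.default_rng(0)
B, C, E, H, k = 2, 16, 8, 12, 5  # batch, visual dim, embed dim, hidden dim, number of tags (made up)

W_v, W_v_h, W_v_att = linear(rng, C, C), linear(rng, H, C), linear(rng, C, C)
W_a, W_a_h = linear(rng, E, E), linear(rng, H, E)
W_a_att = linear(rng, E, 1)  # one score per tag (assumption)
W_fc = linear(rng, C + E, E)

avg_features = rng.normal(size=(B, C))          # one pooled visual vector per image
semantic_features = rng.normal(size=(B, k, E))  # k tag embeddings per image
h_sent = np.zeros((B, 1, H))                    # sentence-LSTM hidden state

# Visual branch: one weight per channel of the pooled vector, so nothing to sum over.
# h_sent[:, 0] mirrors h_sent.squeeze(1).
alpha_v = softmax(W_v_att(np.tanh(W_v(avg_features) + W_v_h(h_sent[:, 0]))), axis=1)  # (B, C)
v_att = alpha_v * avg_features                                                    # (B, C)

# Semantic branch: W_a_h(h_sent) is (B, 1, E) and broadcasts over the k tags.
scores = W_a_att(np.tanh(W_a(semantic_features) + W_a_h(h_sent)))  # (B, k, 1)
alpha_a = softmax(scores, axis=1)                                   # normalized across the k tags
a_att = (alpha_a * semantic_features).sum(axis=1)                   # weighted sum -> (B, E)

ctx = W_fc(np.concatenate([v_att, a_att], axis=1))                  # (B, E)
```

As I read the paper, its visual attention runs over several feature vectors taken from a convolutional layer before pooling. Reproducing that would mean feeding visual_features reshaped to (B, 49, C) into the visual branch and summing over the 49 positions, just as the semantic branch sums over the tags.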

4. SentenceLSTM
An LSTM turns ctx into the topic vector (the input to the downstream word model) and the stop flag p_stop.

class SentenceLSTM(nn.Module):
    def v1(self, ctx, prev_hidden_state, states=None):
        """
        v1 (only training)
        :param ctx: context vector (from co_attention)
        :param prev_hidden_state: hidden state passed in by the caller (previous sentence step)
        :param states: LSTM (h, c) states
        :return:
        """
        # the LSTM runs one step per sentence, so add a length-1 sequence dim to ctx
        ctx = ctx.unsqueeze(1)
        hidden_state, states = self.lstm(ctx, states)
        # topic vector, used by the word LSTM to generate the sentence
        topic = self.W_topic(self.sigmoid(self.bn_t_h(self.W_t_h(hidden_state))
                                          + self.bn_t_ctx(self.W_t_ctx(ctx))))
        # p_stop: whether to stop generating further sentences
        p_stop = self.W_stop(self.sigmoid(self.bn_stop_s_1(self.W_stop_s_1(prev_hidden_state))
                                          + self.bn_stop_s(self.W_stop_s(hidden_state))))
        return topic, p_stop, hidden_state, states
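The two heads of SentenceLSTM can be sketched in NumPy, leaving out the LSTM itself and the batch-norm layers for brevity. All weights are random placeholders and all sizes are made up. I assume ce_criterion is nn.CrossEntropyLoss. That is why W_stop is given two outputs here: the stop labels in prob_real are class indices 0 (stop) and 1 (continue).

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def cross_entropy(logits, labels):
    """Per-sample cross-entropy from raw logits, like nn.CrossEntropyLoss(reduction='none')."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels]


def weight(rng, n_in, n_out):
    """Random placeholder matrix standing in for an nn.Linear layer (no bias)."""
    return rng.normal(scale=0.1, size=(n_in, n_out))


rng = np.random.default_rng(0)
B, H, E = 4, 12, 8  # batch, LSTM hidden size, topic/context size (made up)
W_t_h, W_t_ctx, W_topic = weight(rng, H, E), weight(rng, E, E), weight(rng, E, E)
W_stop_s_1, W_stop_s, W_stop = weight(rng, H, E), weight(rng, H, E), weight(rng, E, 2)

ctx = rng.normal(size=(B, E))   # context vector from co-attention
h_prev = np.zeros((B, H))       # prev_hidden_state (all zeros, as in _epoch_train)
h = rng.normal(size=(B, H))     # LSTM output for the current sentence step

# topic = W_topic(sigmoid(W_t_h h + W_t_ctx ctx))
topic = sigmoid(h @ W_t_h + ctx @ W_t_ctx) @ W_topic             # (B, E)
# p_stop = W_stop(sigmoid(W_stop_s_1 h_prev + W_stop_s h)): two logits per sample
p_stop = sigmoid(h_prev @ W_stop_s_1 + h @ W_stop_s) @ W_stop    # (B, 2)

# One column of prob: 1 = a real sentence (continue), 0 = stop here.
prob_col = np.array([1, 1, 0, 1])
stop_loss = cross_entropy(p_stop, prob_col).sum()
keep_going = p_stop.argmax(axis=1) == 1  # a greedy stop decision at inference time
```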

5. WordLSTM
Generates the words of a sentence from the topic vector produced upstream.

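class WordLSTM(nn.Module):
    def forward(self, topic_vec, captions):
        # during training, captions are the ground-truth report words (teacher forcing)
        embeddings = self.embed(captions)
        # prepend the topic vector as the first time step
        embeddings = torch.cat((topic_vec, embeddings), 1)
        hidden, _ = self.lstm(embeddings)
        # predict the next word from the last time step only
        outputs = self.linear(hidden[:, -1, :])
        return outputs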
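The word loop in _epoch_train re-runs WordLSTM on a growing prefix of the sentence. The small pure-Python sketch below lists what the model sees at each step for one padded sentence; the helper name is hypothetical. It assumes index 0 is padding, as word_mask implies. Since the loop never predicts a sentence's first token, that token presumably acts as a start token.

```python
def teacher_forcing_steps(sentence, pad=0):
    """List the (input, target) pairs the word loop in _epoch_train trains on.

    sentence: zero-padded word indices of one sentence (one row of context[:, sentence_index, :]).
    At step word_index the word model sees the topic vector followed by
    sentence[:word_index] and must predict sentence[word_index]; padding
    targets are skipped here, which is what word_mask does in the loss.
    Hypothetical helper, for illustration only.
    """
    steps = []
    for word_index in range(1, len(sentence)):
        target = sentence[word_index]
        if target == pad:  # word_mask is 0 for this position
            continue
        steps.append((["<topic>"] + sentence[:word_index], target))
    return steps


for inputs, target in teacher_forcing_steps([1, 14, 27, 2, 0, 0]):
    print(inputs, "->", target)
# ['<topic>', 1] -> 14
# ['<topic>', 1, 14] -> 27
# ['<topic>', 1, 14, 27] -> 2
```

An LSTM is causal, so the hidden state at position t depends only on inputs up to t. A single pass over the topic vector plus words 0..T-2 would therefore yield the same per-step predictions. The step-by-step loop costs O(T²) LSTM steps per sentence instead of O(T).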

Training

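Training is in trainer.py, in the subclass LSTMDebugger. One detail worth noting: when training word_model, the topic vector is concatenated with the ground-truth report words. This is teacher forcing, and Transformer decoders are trained the same way.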

def _epoch_train(self):
	tag_loss, stop_loss, word_loss, loss = 0, 0, 0, 0
	self.extractor.train()
	self.mlc.train()
	self.co_attention.train()
	self.sentence_model.train()
	self.word_model.train()
	# iterate over batches of images
	for i, (images, _, label, captions, prob) in enumerate(self.train_data_loader):
		batch_tag_loss, batch_stop_loss, batch_word_loss, batch_loss = 0, 0, 0, 0
		images = self._to_var(images)

		visual_features, avg_features = self.extractor.forward(images)
		tags, semantic_features = self.mlc.forward(avg_features)

		batch_tag_loss = self.mse_criterion(tags, self._to_var(label, requires_grad=False)).sum()

		sentence_states = None
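		# note: prev_hidden_states is never updated during training, which looks odd:
		# co-attention receives the same all-zero state for every sentence of a report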
		prev_hidden_states = self._to_var(torch.zeros(images.shape[0], 1, self.args.hidden_size))

		context = self._to_var(torch.Tensor(captions).long(), requires_grad=False)
		prob_real = self._to_var(torch.Tensor(prob).long(), requires_grad=False)
		# loop over the sentences of each report
		for sentence_index in range(captions.shape[1]):
			ctx, _, _ = self.co_attention.forward(avg_features,semantic_features,prev_hidden_states)

			topic, p_stop, hidden_states, sentence_states = self.sentence_model.forward(ctx,prev_hidden_states,sentence_states)
			# accumulate the stop loss over all sentences
			batch_stop_loss += self.ce_criterion(p_stop.squeeze(), prob_real[:, sentence_index]).sum()
			# loop over the words of the current sentence
			for word_index in range(1, captions.shape[2]):
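				# teacher forcing: the topic vector plus ground-truth words [0, word_index) predict word word_index
				# note: multiplying by word_mask only masks per sample if ce_criterion uses reduction='none'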
				words = self.word_model.forward(topic, context[:, sentence_index, :word_index])
				word_mask = (context[:, sentence_index, word_index] > 0).float()
				batch_word_loss += (self.ce_criterion(words, context[:, sentence_index, word_index])* word_mask).sum() * (0.9 ** word_index)

		batch_loss = self.args.lambda_tag * batch_tag_loss \
					 + self.args.lambda_stop * batch_stop_loss \
					 + self.args.lambda_word * batch_word_loss

		self.optimizer.zero_grad()
		batch_loss.backward()
		# ... (the rest of the loop, e.g. the optimizer step, is not shown in this excerpt)
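The word loss can be reproduced in NumPy to see what the mask and the 0.9 ** word_index factor do. `word_loss` is a hypothetical helper. It computes per-sample cross-entropy the way nn.CrossEntropyLoss(reduction='none') would, which is the setting the masking in the training loop relies on.

```python
import numpy as np


def word_loss(logits_per_step, targets, decay=0.9):
    """Masked, position-decayed cross-entropy, mirroring batch_word_loss for one sentence index.

    logits_per_step: dict word_index -> (B, vocab) logits predicted at that step
    targets: (B, T) int array of ground-truth word indices (0 = padding)
    """
    total = 0.0
    for word_index in range(1, targets.shape[1]):
        logits = logits_per_step[word_index]
        y = targets[:, word_index]
        shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
        log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
        ce = -log_probs[np.arange(len(y)), y]    # per-sample loss, like reduction='none'
        mask = (y > 0).astype(float)             # word_mask: ignore padding targets
        total += (ce * mask).sum() * decay ** word_index
    return total


vocab, T = 6, 4
targets = np.array([[1, 5, 3, 0], [1, 4, 0, 0]])
uniform = {t: np.zeros((2, vocab)) for t in range(1, T)}  # uniform logits: CE = log(6) per word
loss = word_loss(uniform, targets)
```

Here 3 real target words contribute: two at word_index 1 (weight 0.9) and one at word_index 2 (weight 0.81). The padded positions add nothing. In general, the term at word_index = 10 is scaled by 0.9^10 ≈ 0.35, versus 0.9 for the first predicted word. So the tails of long sentences contribute relatively little.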
