Paper: https://arxiv.org/pdf/1711.08195.pdf
GitHub: https://github.com/ZexinYan/Medical-Report-Generation
Given a chest X-ray image, a CNN performs multi-label classification on it. The output of the model's last layer (a 1-D vector) serves as the Visual Features (VF), and the indices of the k most probable classes (topk) serve as the Semantic Features (SF). A Co-Attention module combines VF and SF into a context vector ctx; ctx is fed to a SentenceLSTM, which produces a stop vector and a topic vector. The former signals the end of generation; the latter is passed to a WordLSTM, which emits the report text.
The data pipeline lives mainly in dataset.py. The author's example uses a batch of 4 images; printing the loaded data:
for i, (image, image_id, label, target, prob) in enumerate(data_loader):
    print("image.shape", image.shape)
    print("image_id", image_id)
    print("label.shape", label.shape)
    print("target", target.shape)
    print("prob", prob)
image.shape torch.Size([4, 3, 224, 224]) # 4 RGB images
image_id ('CXR1972_IM-0633-1001.png', 'CXR932_IM-2430-3001.png', 'CXR1149_IM-0101-1001.png', 'CXR3976_IM-2035-1001.png') # image file names
label.shape torch.Size([4, 210]) # class labels for the 4 images
target (4, 6, 18) # target: a report consists of several sentences (split on '.'), each made of several words; 6 = max sentences per report in this batch, 18 = max words per sentence in this batch
prob # a 0 marks the sentence position at which generation of further sentences stops
[[1. 1. 1. 1. 1. 0.]
 [1. 1. 1. 1. 1. 0.]
 [1. 1. 1. 1. 1. 0.]
 [1. 1. 1. 1. 1. 0.]]
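To make the tensor layout concrete, here is a minimal sketch of how a raw report could be turned into the target and prob arrays above. The helper name pad_report, the word2idx mapping, and the padding scheme are assumptions for illustration, not the repo's actual dataset code:

import numpy as np

def pad_report(report, word2idx, max_sent=6, max_word=18):
    """Hypothetical helper: split a report into sentences/words and pad."""
    sentences = [s.split() for s in report.split('.') if s.strip()]
    target = np.zeros((max_sent, max_word), dtype=np.int64)   # word indices, 0-padded
    prob = np.zeros(max_sent, dtype=np.float32)               # 1 = keep generating, 0 = stop
    for i, words in enumerate(sentences[:max_sent]):
        prob[i] = 1.0
        for j, w in enumerate(words[:max_word]):
            target[i, j] = word2idx.get(w, 0)
    return target, prob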
1. CNN model
We only look at each model's core forward function; the default CNN backbone is resnet152.
class VisualFeatureExtractor(nn.Module):
    def forward(self, images):
        # features before pooling; the "Visual Features" in the figure are actually avg_features
        visual_features = self.model(images)
        # self.avg_func = torch.nn.AvgPool2d(kernel_size=7, stride=1, padding=0)
        # collapses the 7x7 spatial map into a single vector per image
        avg_features = self.avg_func(visual_features).squeeze()
        return visual_features, avg_features
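For context, here is a plausible way such an extractor could be wired up with torchvision; the weights flag and the layer slicing are assumptions, not the repo's exact constructor:

import torch
import torch.nn as nn
from torchvision import models

class VisualFeatureExtractorSketch(nn.Module):
    """Minimal sketch: a resnet152 trunk without its avgpool/fc head."""
    def __init__(self):
        super().__init__()
        resnet = models.resnet152(weights=None)  # or load pretrained weights
        # keep everything up to the last conv block: output (N, 2048, 7, 7) for 224x224 input
        self.model = nn.Sequential(*list(resnet.children())[:-2])
        self.avg_func = nn.AvgPool2d(kernel_size=7, stride=1, padding=0)

    def forward(self, images):
        visual_features = self.model(images)                      # (N, 2048, 7, 7)
        avg_features = self.avg_func(visual_features).squeeze()   # (N, 2048)
        return visual_features, avg_features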
2. MLC model
The multi-label classifier: one chest X-ray may carry several tags, and the k most probable tags become the semantic features (the index-to-embedding mapping is already applied here).
class MLC(nn.Module):
    def forward(self, avg_features):
        # multi-label classification over the tag vocabulary
        tags = self.softmax(self.classifier(avg_features))
        # embed the indices of the k most probable tags as semantic features
        semantic_features = self.embed(torch.topk(tags, self.k)[1])
        return tags, semantic_features
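A minimal sketch of how MLC is likely laid out; the sizes are illustrative assumptions (210 matches the label width printed in the data section, 2048 matches the pooled resnet152 features):

import torch
import torch.nn as nn

class MLCSketch(nn.Module):
    """Hypothetical layout of the multi-label classifier."""
    def __init__(self, num_tags=210, embed_size=512, k=10, fc_in=2048):
        super().__init__()
        self.classifier = nn.Linear(fc_in, num_tags)
        self.embed = nn.Embedding(num_tags, embed_size)
        self.softmax = nn.Softmax(dim=-1)
        self.k = k

    def forward(self, avg_features):
        tags = self.softmax(self.classifier(avg_features))  # (N, 210)
        topk_idx = torch.topk(tags, self.k)[1]              # (N, k) tag indices
        semantic_features = self.embed(topk_idx)            # (N, k, 512)
        return tags, semantic_features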
3. CoAttention model
Following the paper's equations, multiplying a variable by a weight matrix is implemented in code as an nn.Linear layer. The forward method offers five variants, v1 through v5; below is the simplest one, v2.
class CoAttention(nn.Module):
    def v2(self, avg_features, semantic_features, h_sent) -> object:
        """
        h_sent: hidden state of the SentenceLSTM
        no bn (v2 skips batch norm)
        :rtype: object
        """
        W_v = self.W_v(avg_features)
        W_v_h = self.W_v_h(h_sent.squeeze(1))
        # attention weights over the visual features
        alpha_v = self.softmax(self.W_v_att(self.tanh(W_v + W_v_h)))
        # no .sum here: avg_features is a single pooled vector (N, 2048), so the
        # elementwise product is already a vector; there is no position axis to
        # reduce over, unlike the k semantic features below
        v_att = torch.mul(alpha_v, avg_features)
        # attention weights over the semantic features
        W_a_h = self.W_a_h(h_sent)
        W_a = self.W_a(semantic_features)
        alpha_a = self.softmax(self.W_a_att(self.tanh(torch.add(W_a_h, W_a))))
        a_att = torch.mul(alpha_a, semantic_features).sum(1)
        # fuse the two attended features into the context vector
        ctx = self.W_fc(torch.cat([v_att, a_att], dim=1))
        return ctx, alpha_v, alpha_a
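Written out (with a = avg_features, A = the k semantic feature vectors, h = h_sent; biases omitted), v2 computes:

\[
\begin{aligned}
\alpha_v &= \operatorname{softmax}\big(W_{v\_att}\tanh(W_v\,a + W_{v\_h}\,h)\big), & v_{att} &= \alpha_v \odot a \\
\alpha_a &= \operatorname{softmax}\big(W_{a\_att}\tanh(W_{a\_h}\,h + W_a\,A)\big), & a_{att} &= \sum_{i=1}^{k} \alpha_{a,i} \odot A_i \\
ctx &= W_{fc}\,[\,v_{att};\,a_{att}\,]
\end{aligned}
\]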
4. SentenceLSTM
An LSTM that, from ctx, produces the topic vector consumed by the downstream report generator and the stop signal p_stop.
class SentenceLSTM(nn.Module):
    def v1(self, ctx, prev_hidden_state, states=None):
        """
        v1 (only training)
        :param ctx: context vector (from co_attention)
        :param prev_hidden_state: externally managed hidden state
        :param states: LSTM hidden state
        :return:
        """
        # the LSTM sequence length is 1, so add a length dimension to ctx
        ctx = ctx.unsqueeze(1)
        hidden_state, states = self.lstm(ctx, states)
        # topic drives the report generation
        topic = self.W_topic(self.sigmoid(self.bn_t_h(self.W_t_h(hidden_state))
                                          + self.bn_t_ctx(self.W_t_ctx(ctx))))
        # p_stop flags the end of generation
        p_stop = self.W_stop(self.sigmoid(self.bn_stop_s_1(self.W_stop_s_1(prev_hidden_state))
                                          + self.bn_stop_s(self.W_stop_s(hidden_state))))
        return topic, p_stop, hidden_state, states
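In equation form (σ = sigmoid, BN = batch norm, h_t the new hidden state, h_{t-1} = prev_hidden_state):

\[
\begin{aligned}
topic_t &= W_{topic}\,\sigma\big(\mathrm{BN}(W_{t\_h}\,h_t) + \mathrm{BN}(W_{t\_ctx}\,ctx_t)\big) \\
p^{stop}_t &= W_{stop}\,\sigma\big(\mathrm{BN}(W_{stop\_s\_1}\,h_{t-1}) + \mathrm{BN}(W_{stop\_s}\,h_t)\big)
\end{aligned}
\]

p_stop is a pair of logits (continue, stop): the trainer below applies cross-entropy between it and the 0/1 prob labels printed in the data section.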
5. WordLSTM
Generates the report text from the upstream topic vector.
class WordLSTM(nn.Module):
    def forward(self, topic_vec, captions):
        # during training, captions is the ground-truth report prefix
        embeddings = self.embed(captions)
        embeddings = torch.cat((topic_vec, embeddings), 1)
        hidden, _ = self.lstm(embeddings)
        # only the last timestep is kept: logits for the single next word
        outputs = self.linear(hidden[:, -1, :])
        return outputs
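Because forward only returns logits for the next word after the whole prefix, decoding at test time must feed back a growing prefix. A hedged sketch of such a sampler (this is not the repo's actual sampler; start_id and max_len are assumptions):

import torch

def greedy_decode(word_model, topic_vec, max_len=18, start_id=1):
    """Hypothetical greedy sampler built on WordLSTM.forward."""
    words = torch.full((topic_vec.size(0), 1), start_id, dtype=torch.long)
    for _ in range(max_len):
        logits = word_model(topic_vec, words)          # (N, vocab_size)
        next_word = logits.argmax(dim=1, keepdim=True) # (N, 1)
        words = torch.cat([words, next_word], dim=1)
    return words[:, 1:]  # drop the start token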
Training lives in trainer.py, specifically in the subclass LSTMDebugger. One interesting detail: when training the word model, the target report and the topic vector are fused together, i.e. teacher forcing (Transformers play the same trick).
def _epoch_train(self):
    tag_loss, stop_loss, word_loss, loss = 0, 0, 0, 0
    self.extractor.train()
    self.mlc.train()
    self.co_attention.train()
    self.sentence_model.train()
    self.word_model.train()
    # fetch a batch of images
    for i, (images, _, label, captions, prob) in enumerate(self.train_data_loader):
        batch_tag_loss, batch_stop_loss, batch_word_loss, batch_loss = 0, 0, 0, 0
        images = self._to_var(images)
        visual_features, avg_features = self.extractor.forward(images)
        tags, semantic_features = self.mlc.forward(avg_features)
        batch_tag_loss = self.mse_criterion(tags, self._to_var(label, requires_grad=False)).sum()
        sentence_states = None
        # prev_hidden_states is never updated during training, which feels a bit odd
        prev_hidden_states = self._to_var(torch.zeros(images.shape[0], 1, self.args.hidden_size))
        context = self._to_var(torch.Tensor(captions).long(), requires_grad=False)
        prob_real = self._to_var(torch.Tensor(prob).long(), requires_grad=False)
        # iterate over each sentence of the report
        for sentence_index in range(captions.shape[1]):
            ctx, _, _ = self.co_attention.forward(avg_features, semantic_features, prev_hidden_states)
            topic, p_stop, hidden_states, sentence_states = self.sentence_model.forward(ctx, prev_hidden_states, sentence_states)
            # accumulate the stop loss over the sentences of the report
            batch_stop_loss += self.ce_criterion(p_stop.squeeze(), prob_real[:, sentence_index]).sum()
            # iterate over the words of this sentence
            for word_index in range(1, captions.shape[2]):
                # the interesting part: the ground-truth prefix is fused with topic (teacher forcing)
                words = self.word_model.forward(topic, context[:, sentence_index, :word_index])
                word_mask = (context[:, sentence_index, word_index] > 0).float()
                batch_word_loss += (self.ce_criterion(words, context[:, sentence_index, word_index]) * word_mask).sum() * (0.9 ** word_index)
        batch_loss = self.args.lambda_tag * batch_tag_loss \
                     + self.args.lambda_stop * batch_stop_loss \
                     + self.args.lambda_word * batch_word_loss
        self.optimizer.zero_grad()
        batch_loss.backward()
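At test time the same modules are chained sentence by sentence until p_stop says to stop. A hedged sketch of that loop (not the repo's actual sampler: the 0.5 threshold, the cap of 6 sentences, and greedy_decode from the earlier sketch are assumptions; unlike the training loop above, prev_hidden is updated each step):

import torch

@torch.no_grad()
def generate_report(extractor, mlc, co_attention, sentence_model, word_model,
                    images, hidden_size=512, max_sentences=6):
    """Hypothetical end-to-end decoding loop, assuming batch size 1."""
    _, avg_features = extractor(images)
    _, semantic_features = mlc(avg_features)
    prev_hidden = torch.zeros(images.size(0), 1, hidden_size)
    states, sentences = None, []
    for _ in range(max_sentences):
        ctx, _, _ = co_attention(avg_features, semantic_features, prev_hidden)
        topic, p_stop, prev_hidden, states = sentence_model(ctx, prev_hidden, states)
        # p_stop holds two logits (continue, stop); stop once "stop" wins
        if p_stop.squeeze().softmax(-1)[1].item() > 0.5:
            break
        sentences.append(greedy_decode(word_model, topic))
    return sentences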