每个iteration包含多个batch,也就是多个episode;每个episode随机选取classes_per_it个类别,每个类别随机选择sample_per_class个样本,组成support set;query set则由从这些类别中随机选出的一个类的一个随机样本组成。由于这些样本是作为一个序列输入到模型中的,所以最后一个样本即为query set,也就是要预测标签的样本。输入时,将一个batch中所有episode的样本拼接起来一起输入。
将图像输入到时序卷积网络前,先要对图像做特征提取
class CasualConv1d(nn.Module):
    """Causal 1-D convolution (the "Casual" spelling is kept for compatibility).

    ``nn.Conv1d`` pads ``dilation * (kernel_size - 1)`` steps on BOTH sides;
    trimming that same amount from the right end leaves an output where step
    ``t`` depends only on input steps ``<= t``.

    Shapes (stride == 1): input ``(N, in_channels, T)`` ->
    output ``(N, out_channels, T)``. With a larger stride the output is not
    length-preserving.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, dilation=1, groups=1, bias=True):
        super(CasualConv1d, self).__init__()
        self.dilation = dilation
        # Past context spanned by the kernel; also the surplus produced on the right.
        self.padding = dilation * (kernel_size - 1)
        self.conv1d = nn.Conv1d(in_channels, out_channels, kernel_size, stride,
                                self.padding, dilation, groups, bias)

    def forward(self, input):
        out = self.conv1d(input)
        # Trim by the padding, not by ``dilation``: the two only coincide when
        # kernel_size == 2. Guard the zero case, since ``out[..., :-0]`` is empty.
        if self.padding:
            out = out[:, :, :-self.padding]
        return out
class DenseBlock(nn.Module):
    """Gated causal-convolution layer that grows the channel dimension.

    Two causal convolutions run side by side over the same input. One output
    is squashed with tanh (the "filter"), the other with a sigmoid (the
    "gate"). Their product, ``filters`` new channels, is appended to the
    input along the channel axis.

    Args:
        in_channels: number of input channels.
        dilation: dilation passed to both causal convolutions.
        filters: number of channels produced by each convolution.
        kernel_size: convolution kernel width (default 2).

    Input is ``(N, in_channels, T)``. The result is
    ``(N, in_channels + filters, T)`` provided CasualConv1d keeps length ``T``.
    """

    def __init__(self, in_channels, dilation, filters, kernel_size=2):
        super(DenseBlock, self).__init__()
        conv_args = (in_channels, filters, kernel_size)
        # Attribute names are kept as-is: they are the state_dict keys.
        self.casualconv1 = CasualConv1d(*conv_args, dilation=dilation)
        self.casualconv2 = CasualConv1d(*conv_args, dilation=dilation)

    def forward(self, input):
        filt = torch.tanh(self.casualconv1(input))
        gate = torch.sigmoid(self.casualconv2(input))
        return torch.cat([input, filt * gate], dim=1)
class TCBlock(nn.Module):
    """Temporal-convolution block: a stack of DenseBlocks with growing dilation.

    Uses ``ceil(log2(seq_length))`` DenseBlocks with dilations 2, 4, 8, ...
    (``2 ** (i + 1)``). Each block appends ``filters`` channels, so the
    output has ``in_channels + ceil(log2(seq_length)) * filters`` channels.
    NOTE(review): whatever builds the layers after this block must compute
    its channel count with the same log2 formula.

    Args:
        in_channels: feature size of the input.
        seq_length: sequence length ``T`` the block must cover (>= 1).
        filters: channels added by each DenseBlock.

    Input ``(N, T, in_channels)`` -> output ``(N, T, out_channels)``.
    """

    def __init__(self, in_channels, seq_length, filters):
        super(TCBlock, self).__init__()
        # SNAIL stacks ceil(log2 T) dense blocks. math.log is the natural log
        # and under-counts them. math.log2 is exact for powers of two, whereas
        # math.log(x, 2) can land one ulp above the integer and ceil() to +1.
        num_blocks = int(math.ceil(math.log2(seq_length)))
        self.dense_blocks = nn.ModuleList([
            DenseBlock(in_channels + i * filters, 2 ** (i + 1), filters)
            for i in range(num_blocks)
        ])

    def forward(self, input):
        # (N, T, C) -> (N, C, T): the convolutions expect channels before time.
        x = torch.transpose(input, 1, 2)
        for block in self.dense_blocks:
            x = block(x)
        return torch.transpose(x, 1, 2)
soft attention可以让模型在可能无限大的上下文中精确地定位信息:它把上下文信息当作无序的键值对,通过内容对其进行查找。
class AttentionBlock(nn.Module):
    """Causally masked, single-head scaled dot-product attention.

    Every time step ``t`` attends over steps ``0..t`` (itself and the past).
    The attention read-out is concatenated onto the input features.

    Args:
        in_channels: feature size of the input.
        key_size: dimensionality of the query/key projections.
        value_size: dimensionality of the value projection (channels added).

    Input ``(N, T, in_channels)`` -> output ``(N, T, in_channels + value_size)``.
    """

    def __init__(self, in_channels, key_size, value_size):
        super(AttentionBlock, self).__init__()
        self.linear_query = nn.Linear(in_channels, key_size)
        self.linear_keys = nn.Linear(in_channels, key_size)
        self.linear_values = nn.Linear(in_channels, value_size)
        self.sqrt_key_size = math.sqrt(key_size)

    def forward(self, input):
        seq_len = input.shape[1]
        # mask[q, k] is True where key k lies in the future of query q (k > q).
        # It is built on the input's device (no hard-coded .cuda()) and as a
        # bool tensor, which masked_fill requires.
        mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=input.device),
            diagonal=1,
        )
        keys = self.linear_keys(input)      # (N, T, key_size)
        query = self.linear_query(input)    # (N, T, key_size)
        values = self.linear_values(input)  # (N, T, value_size)
        scores = torch.bmm(query, keys.transpose(1, 2)) / self.sqrt_key_size  # (N, T, T)
        # Out-of-place fill keeps autograd aware of the masking (no .data).
        scores = scores.masked_fill(mask, float('-inf'))
        # Normalise over the key axis (dim=2) so each query's weights sum to 1.
        # The previous dim=1 normalised over queries and leaked future steps
        # into earlier outputs. The diagonal is never masked, so no row is all -inf.
        weights = F.softmax(scores, dim=2)
        read = torch.bmm(weights, values)   # (N, T, value_size)
        return torch.cat((input, read), dim=2)
过程与Prototypical Network相同
Model | 5-way 1-shot Acc. | 5-way 5-shot Acc. | 20-way 1-shot Acc. | 20-way 5-shot Acc. |
---|---|---|---|---|
Reference Paper | 99.07% | 99.78% | 97.64% | 99.36% |
This repo | 98.31% | 99.26% | 93.75% | 97.88% |