eg1: torch.eq —— 逐元素比较，相等的元素为 True（形状不同时会先广播，例如 (5,1) 与 (1,5) 比较得到 5×5 结果）
torch.eq(torch.arange(5).view(5,1),torch.arange(5).view(5,1))
Out[28]:
tensor([[True],
[True],
[True],
[True],
[True]])
torch.eq(torch.arange(5).view(5,1),torch.arange(5).view(5,1).T)
Out[29]:
tensor([[ True, False, False, False, False],
[False, True, False, False, False],
[False, False, True, False, False],
[False, False, False, True, False],
[False, False, False, False, True]])
torch.eq(torch.arange(5).view(5,1).T,torch.arange(5).view(5,1))
Out[30]:
tensor([[ True, False, False, False, False],
[False, True, False, False, False],
[False, False, True, False, False],
[False, False, False, True, False],
[False, False, False, False, True]])
eg2: torch.unbind —— 沿指定维度拆开张量，返回一个元组（该维度被移除）
torch.unbind(torch.tensor([[1],[2],[3]]),1)
Out[36]: (tensor([1, 2, 3]),)
torch.unbind(torch.tensor([[1],[2],[3]]),0)
Out[37]: (tensor([1]), tensor([2]), tensor([3]))
eg3: 对 features（形状 [bsz, n_views, ...]）沿 dim=1 做 unbind，得到 n_views 个 [bsz, ...] 张量（这里 n_views=2）
torch.unbind(features,dim=1)
Out[44]:
(tensor([[ 0.0165, -0.1257, 0.0335, ..., 0.0430, 0.0588, 0.0256],
[ 0.0581, -0.0996, -0.0443, ..., -0.0111, 0.1081, -0.0078],
[ 0.0172, -0.1306, -0.0858, ..., -0.0411, 0.0833, 0.0013],
...,
[ 0.0601, -0.1264, -0.0413, ..., 0.0127, 0.1198, -0.0309],
[-0.0102, -0.1497, 0.0010, ..., -0.0122, 0.1112, -0.0583],
[ 0.0758, -0.1189, -0.0197, ..., 0.0220, 0.0872, -0.0166]],
device='cuda:0', grad_fn=),
tensor([[ 0.0165, -0.1257, 0.0335, ..., 0.0430, 0.0588, 0.0256],
[ 0.0581, -0.0996, -0.0443, ..., -0.0111, 0.1081, -0.0078],
[ 0.0172, -0.1306, -0.0858, ..., -0.0411, 0.0833, 0.0013],
...,
[ 0.0601, -0.1264, -0.0413, ..., 0.0127, 0.1198, -0.0309],
[-0.0102, -0.1497, 0.0010, ..., -0.0122, 0.1112, -0.0583],
[ 0.0758, -0.1189, -0.0197, ..., 0.0220, 0.0872, -0.0166]],
device='cuda:0', grad_fn=))
eg4: t.repeat()
eg5: torch.div()
eg6: torch.scatter 构造 logits_mask（dim=1 时每行的对角位置被置 0，用于去除样本与自身的对比）
# tile mask
mask = mask.repeat(anchor_count, contrast_count)
# mask-out self-contrast cases
logits_mask = torch.scatter(
    torch.ones_like(mask),
    1,
    torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
    0
)
logits_mask
tensor([[0., 1., 1., ..., 1., 1., 1.],
[1., 0., 1., ..., 1., 1., 1.],
[1., 1., 0., ..., 1., 1., 1.],
...,
[1., 1., 1., ..., 0., 1., 1.],
[1., 1., 1., ..., 1., 0., 1.],
[1., 1., 1., ..., 1., 1., 0.]], device='cuda:0')
---- 对比 dim=0：index[i][0] = i，dim=1 时写入 out[i][i]（对角线置 0）；dim=0 时写入 out[i][0]，因此结果是第 0 列全部为 0
torch.scatter(
torch.ones_like(mask),
0,
torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
0)
tensor([[0., 1., 1., ..., 1., 1., 1.],
[0., 1., 1., ..., 1., 1., 1.],
[0., 1., 1., ..., 1., 1., 1.],
...,
[0., 1., 1., ..., 1., 1., 1.],
[0., 1., 1., ..., 1., 1., 1.],
[0., 1., 1., ..., 1., 1., 1.]], device='cuda:0')
【笔记】scatter_函数:用法如 torch.zeros(target.size(0), 2).scatter_(1,target,1).to(self.device)_程序猿的探索之路的博客-CSDN博客
eg7: mask * logits_mask 消除对角元素（去掉样本与自身构成的正样本对）
mask = mask.repeat(anchor_count, contrast_count)
# mask-out self-contrast cases
logits_mask = torch.scatter(
    torch.ones_like(mask),
    1,
    torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
    0
)
mask = mask * logits_mask
mask
Out[20]:
tensor([[1., 0., 1., ..., 1., 0., 0.],
[0., 1., 0., ..., 0., 0., 0.],
[1., 0., 1., ..., 1., 0., 0.],
...,
[1., 0., 1., ..., 1., 0., 0.],
[0., 0., 0., ..., 0., 1., 0.],
[0., 0., 0., ..., 0., 0., 1.]], device='cuda:0')
mask*logits_mask
Out[21]:
tensor([[0., 0., 1., ..., 1., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[1., 0., 0., ..., 1., 0., 0.],
...,
[1., 0., 1., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]], device='cuda:0')
eg8:
anchor_count=2, batch_size=64: loss = loss.view(anchor_count, batch_size)
tensor([[4.7642, 4.7700, 4.9274, 4.8298, 5.0163, 4.8054, 4.9344, 4.6297, 5.0533,
4.8786, 4.8808, 4.8946, 4.5965, 4.9085, 4.6794, 4.9939, 4.8648, 4.8382,
4.5422, 4.7529, 4.6383, 4.7940, 4.7202, 4.9732, 4.5696, 4.7187, 4.8346,
4.8804, 4.5355, 4.7395, 4.8884, 4.7580, 5.0020, 4.9140, 5.2952, 4.7402,
4.8660, 4.9400, 4.9015, 4.8370, 5.0518, 4.8339, 5.0241, 4.8498, 5.0187,
4.6112, 4.6124, 4.7228, 4.8453, 4.6810, 4.7281, 4.7040, 4.8005, 5.0514,
5.0573, 4.2868, 4.9171, 4.5031, 4.7733, 4.8827, 4.7193, 4.9463, 4.8855,
4.9188],
[4.7642, 4.7700, 4.9274, 4.8298, 5.0163, 4.8054, 4.9344, 4.6297, 5.0533,
4.8786, 4.8808, 4.8946, 4.5965, 4.9085, 4.6794, 4.9939, 4.8648, 4.8382,
4.5422, 4.7529, 4.6383, 4.7940, 4.7202, 4.9732, 4.5696, 4.7187, 4.8346,
4.8804, 4.5355, 4.7395, 4.8884, 4.7580, 5.0020, 4.9140, 5.2952, 4.7402,
4.8660, 4.9400, 4.9015, 4.8370, 5.0518, 4.8339, 5.0241, 4.8498, 5.0187,
4.6112, 4.6124, 4.7228, 4.8453, 4.6810, 4.7281, 4.7040, 4.8005, 5.0514,
5.0573, 4.2868, 4.9171, 4.5031, 4.7733, 4.8827, 4.7193, 4.9463, 4.8855,
4.9188]], device='cuda:0', grad_fn=
loss = loss.view(anchor_count, batch_size).mean()
tensor(4.8208, device='cuda:0', grad_fn=
loss.view(anchor_count, batch_size).mean(0)
tensor([4.7642, 4.7700, 4.9274, 4.8298, 5.0163, 4.8054, 4.9344, 4.6297, 5.0533,
4.8786, 4.8808, 4.8946, 4.5965, 4.9085, 4.6794, 4.9939, 4.8648, 4.8382,
4.5422, 4.7529, 4.6383, 4.7940, 4.7202, 4.9732, 4.5696, 4.7187, 4.8346,
4.8804, 4.5355, 4.7395, 4.8884, 4.7580, 5.0020, 4.9140, 5.2952, 4.7402,
4.8660, 4.9400, 4.9015, 4.8370, 5.0518, 4.8339, 5.0241, 4.8498, 5.0187,
4.6112, 4.6124, 4.7228, 4.8453, 4.6810, 4.7281, 4.7040, 4.8005, 5.0514,
5.0573, 4.2868, 4.9171, 4.5031, 4.7733, 4.8827, 4.7193, 4.9463, 4.8855,
4.9188], device='cuda:0', grad_fn=
loss.view(anchor_count, batch_size).mean(1)
tensor([4.8208, 4.8208], device='cuda:0', grad_fn=
class SupConLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.

    It also supports the unsupervised contrastive loss in SimCLR.

    Args:
        temperature: divisor applied to the anchor/contrast dot products.
        contrast_mode: ``'all'`` uses every view as an anchor; ``'one'`` uses
            only the first view (``features[:, 0]``) as anchors.
        base_temperature: the final loss is scaled by
            ``temperature / base_temperature``.
    """

    def __init__(self, temperature=0.07, contrast_mode='all',
                 base_temperature=0.07):
        super().__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        Args:
            features: hidden vector of shape [bsz, n_views, ...]. Dimensions
                after the second are flattened into one feature dimension.
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.

        Returns:
            A loss scalar. Anchors that have no positive pair contribute 0
            instead of NaN.

        Raises:
            ValueError: if `features` has fewer than 3 dims, if both `labels`
                and `mask` are given, if the number of labels does not match
                the batch size, or if `contrast_mode` is unknown.
        """
        # Create every helper tensor on the exact device of `features`
        # (including a specific GPU index such as cuda:1), so the mask
        # arithmetic below never mixes devices.
        device = features.device

        if features.dim() < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...], '
                             'at least 3 dimensions are required')
        if features.dim() > 3:
            # reshape (rather than view) also accepts non-contiguous input.
            features = features.reshape(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            # SimCLR: a sample's only positives are its own other views.
            mask = torch.eye(batch_size, dtype=torch.float32, device=device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        # [bsz, n_views, d] -> [n_views * bsz, d], ordered view-major:
        # rows 0..bsz-1 are view 0, rows bsz..2*bsz-1 are view 1, and so on.
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits: [anchor_count * bsz, contrast_count * bsz]
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # Subtract the per-row max for numerical stability. It is detached
        # because log-softmax is invariant to a per-row shift.
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # Tile the [bsz, bsz] mask to the logits' view-major layout.
        mask = mask.repeat(anchor_count, contrast_count)
        # All ones except 0 at (i, i): an anchor is never contrasted with
        # itself.
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count, device=device).view(-1, 1),
            0,
        )
        mask = mask * logits_mask

        # compute log_prob (log-softmax over all non-self contrasts)
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # Mean log-likelihood over positives. A row with no positives, e.g.
        # n_views == 1 and a label seen once in the batch, would give
        # 0 / 0 = NaN and poison the mean. Divide such rows by 1 instead, so
        # they contribute exactly 0.
        pos_count = mask.sum(1)
        pos_count = torch.where(pos_count != 0, pos_count,
                                torch.ones_like(pos_count))
        mean_log_prob_pos = (mask * log_prob).sum(1) / pos_count

        # loss
        loss = -(self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()
        return loss