torch.Tensor.repeat vs torch.Tensor.expand

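Tensor.repeat — copies the data, so every repeated slice owns its own memory and can be written to independently.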

import torch

# Demo: Tensor.repeat physically copies q's data, so each tiled slice owns
# its own memory and an in-place masked write only touches the chosen rows.

# Entry (b, j) of v_sum refers to slot j of batch b; a 0 marks it for masking.
v_sum = torch.tensor([[1, 3, 0], [1, 0, 0]])
print(v_sum)

# (k, 2) tensor of [batch, slot] coordinates where v_sum is zero.
mask_index = (v_sum == 0).nonzero()
print(mask_index)

q = torch.rand(2, 1, 6)
print(q)

# Tile dim 1 three times: (2, 1, 6) -> (2, 3, 6), backed by fresh storage.
q_expand = q.repeat((1, 3, 1))
print(q_expand)

# Transposing mask_index gives (batch_indices, slot_indices); indexing with
# that tuple zeroes each selected (batch, slot) row. q itself is untouched.
q_expand[tuple(mask_index.t())] = 0
print(q_expand)

[Figure 1 from the original post: torch.repeat vs torch.expand (image not reproduced)]

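Tensor.expand — returns a view without copying (the expanded dimension gets stride 0), so all expanded slices share memory with the source tensor; clone() the result before writing to it in place.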

import torch

# Demo: Tensor.expand returns a *view* of q. No data is copied; the expanded
# dimension gets stride 0, so every expanded slot aliases the same memory.

# Entry (b, j) of v_sum refers to slot j of batch b; a 0 marks it for masking.
v_sum = torch.tensor([[1, 3, 0], [1, 1, 1]])
print(v_sum)

# (k, 2) tensor of [batch, slot] coordinates where v_sum is zero.
mask_index = torch.nonzero(v_sum == 0)
print(mask_index)

q = torch.rand([2, 1, 6])
print(q)

# Equivalent to q.expand(2, 3, 6); -1 keeps that dimension's existing size.
q_expand = q.expand(-1, 3, -1)
print(q_expand)

# Writing in place through the expanded view is unsafe. All three slots of a
# batch point at the same memory, so zeroing slot (0, 2) would also zero
# slots (0, 0) and (0, 1) *and* q[0, 0]. PyTorch flags index_put_ on expanded
# tensors as deprecated and asks callers to clone() first. clone()
# materialises an independent (2, 3, 6) tensor, so only the masked row
# changes and q is left intact.
q_expand = q_expand.clone()
q_expand[mask_index[:, 0], mask_index[:, 1]] = 0
print(q_expand)

[Figure 2 from the original post: torch.repeat vs torch.expand (image not reproduced)]

You may also be interested in: (pytorch)