import torch

t = torch.ones((2, 5))
# split sizes [2, 1, 1] sum to 4, not 5, so this line raises a RuntimeError
list_of_tensors = torch.split(t, [2, 1, 1], dim=1)
print("t: {}".format(t))
for idx, t in enumerate(list_of_tensors):
    print("tensor #{}: {}, shape is {}".format(idx + 1, t, t.shape))
RuntimeError: split_with_sizes expects split_sizes to sum exactly to 5 (input tensor's size at dimension 1), but got split_sizes=[2, 1, 1]
This says that in list_of_tensors = torch.split(t, [2, 1, 1], dim=1), dimension 1 of the input tensor t has size 5, but split_sizes = [2, 1, 1] sums to 4, not 5. Make the sizes sum to 5 and the error goes away.
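For example, changing the sizes to [2, 2, 1], which sums to 5, splits cleanly into chunks of widths 2, 2, and 1:

# corrected: 2 + 2 + 1 == 5 == t.shape[1]
t = torch.ones((2, 5))
list_of_tensors = torch.split(t, [2, 2, 1], dim=1)
for idx, chunk in enumerate(list_of_tensors):
    print("tensor #{}: {}, shape is {}".format(idx + 1, chunk, chunk.shape))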
t_select = torch.index_select(t, dim=1, index=idx)
t = torch.randint(0, 9, (3, 3))
idx = torch.tensor([0, 2], dtype=torch.long)
# dim=1: pick columns 0 and 2
t_select = torch.index_select(t, dim=1, index=idx)
print("t:\n{}\nt_select:\n{}".format(t, t_select))
Output:
t:
tensor([[4, 6, 6],
[1, 5, 6],
[0, 0, 8]])
t_select:
tensor([[4, 6],
[1, 6],
[0, 8]])
t_select = torch.index_select(t, dim=0, index=idx)
t = torch.randint(0, 9, (3, 3))
idx = torch.tensor([0, 2], dtype=torch.long)
# dim=0: pick rows 0 and 2
t_select = torch.index_select(t, dim=0, index=idx)
print("t:\n{}\nt_select:\n{}".format(t, t_select))
Output:
t:
tensor([[0, 6, 0],
[4, 1, 0],
[2, 2, 6]])
t_select:
tensor([[0, 6, 0],
[2, 2, 6]])
idx = torch.tensor([0, 2], dtype=torch.float)
t = torch.randint(0, 9, (3, 3))
# index is a float tensor here, which triggers the error below
idx = torch.tensor([0, 2], dtype=torch.float)
t_select = torch.index_select(t, dim=0, index=idx)
print("t:\n{}\nt_select:\n{}".format(t, t_select))
Error:
RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #3 'index'
The index argument of t_select = torch.index_select(t, dim=0, index=idx) must be of type long; passing a float tensor raises this error. Changing it to idx = torch.tensor([0, 2], dtype=torch.long) fixes it.
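If the index tensor arrives as float from somewhere else, casting it before indexing is also enough; a minimal sketch:

t = torch.randint(0, 9, (3, 3))
idx = torch.tensor([0, 2], dtype=torch.float)
# cast the float index to long before indexing
t_select = torch.index_select(t, dim=0, index=idx.long())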
import torch
torch.manual_seed(10)

flag = True
# flag = False
if flag:
    w = torch.tensor([1.], requires_grad=True)
    x = torch.tensor([2.], requires_grad=True)
    a = torch.add(w, x)
    a.retain_grad()
    b = torch.add(w, 1)
    y = torch.mul(a, b)
    y.backward()
    # print(w.grad)
    y.backward()  # second backward: the graph was freed after the first call
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
Backward is called twice here. retain_graph defaults to False, so the computation graph is freed after the first backward pass; to backpropagate through the same graph again, set retain_graph=True on the first call.
y.backward(retain_graph=True)  # keep the graph so a second backward is possible
# print(w.grad)
y.backward()
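Also note that gradients accumulate across backward calls rather than being overwritten. In this example dy/dw = (w + 1) + (w + x) = 2 + 3 = 5, so w.grad is 5 after the first backward and 10 after the second; a minimal self-contained sketch:

import torch

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
y = torch.mul(torch.add(w, x), torch.add(w, 1))  # y = (w + x) * (w + 1)

y.backward(retain_graph=True)
print(w.grad)   # tensor([5.])
y.backward()
print(w.grad)   # tensor([10.]) -- accumulated, not overwritten
w.grad.zero_()  # zero it between passes if accumulation is unwanted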