flag = True
if flag:
t = torch.ones((2,3))
t_0 = torch.cat([t,t],dim=0)
t_1 = torch.cat([t,t],dim=1)
print("t_0:{} shape:{}\nt_1:{} shape:{}".format(t_0,t_0.shape,t_1,t_1.shape))
第二个例子:
flag = True
if flag:
t = torch.ones((2,3))
t_stack = torch.stack([t,t],dim=2)
print("t_stack:{} shape:{}".format(t_stack,t_stack.shape))
结果为:
这个理解起来可能有点费劲,但是没关系,好理解的博客我已经给你找好了,
看看传送门~~~ 这篇,可能会更好理解一点。
flag = True
if flag:
a = torch.ones((2,5))
list_of_tensors = torch.chunk(a,dim=1,chunks=2)
for idx,t in enumerate(list_of_tensors):
print("第{}个张量:{},shape is {}".format(idx+1,t,t.shape))
flag = True
if flag:
t = torch.ones((2,5))
list_of_tensors = torch.split(t,2,dim=1)
for idx,t in enumerate(list_of_tensors):
print("第{}个张量:{},shape is {}".format(idx+1,t,t.shape))
flag = True
if flag:
t = torch.randint(0,9,size=(3,3))
idx = torch.tensor([0,2],dtype=torch.long) #一般这里只能用long类型
t_select = torch.index_select(t,dim=0,index=idx)
print("t:\n{}\nt_select:\n{}".format(t,t_select))
flag = True
if flag:
t = torch.randint(0,9,size=(3,3))
mask = t.ge(5) #指生成数组中大于等于5的元素为True
t_select = torch.masked_select(t,mask)
print("t:\n{}\nmask:\n{}\nt_select:\n{}".format(t,mask,t_select))
flag = True
if flag:
t = torch.rand((1,2,3,1))
t_sq = torch.squeeze(t)
t_0 = torch.squeeze(t,dim=0)
t_1 = torch.squeeze(t,dim=1)
print(t.shape)
print(t_sq.shape)
print(t_0.shape)
print(t_1.shape)