pytorch(二)--张量的操作汇总

一,张量的拼接
pytorch(二)--张量的操作汇总_第1张图片我举两个例子来分别说明下:

flag = True

if flag:
        t = torch.ones((2,3))

        t_0 = torch.cat([t,t],dim=0)
        t_1 = torch.cat([t,t],dim=1)

        print("t_0:{} shape:{}\nt_1:{} shape:{}".format(t_0,t_0.shape,t_1,t_1.shape))

运行处的结果为:pytorch(二)--张量的操作汇总_第2张图片
这一个很好理解。

第二个例子:

flag = True

if flag:
    t = torch.ones((2,3))

    t_stack = torch.stack([t,t],dim=2)

    print("t_stack:{} shape:{}".format(t_stack,t_stack.shape))

结果为:
pytorch(二)--张量的操作汇总_第3张图片
这个理解起来可能有点费劲,但是没关系,好理解的博客我已经给你找好了,
看看传送门~~~ 这篇,可能会更好理解一点。

二,张量的切分
第一种方法:
pytorch(二)--张量的操作汇总_第4张图片举个例子:

flag = True

if flag:
    a = torch.ones((2,5))
    list_of_tensors = torch.chunk(a,dim=1,chunks=2)

    for idx,t in enumerate(list_of_tensors):
        print("第{}个张量:{},shape is {}".format(idx+1,t,t.shape))

pytorch(二)--张量的操作汇总_第5张图片

第二种方法:
pytorch(二)--张量的操作汇总_第6张图片举个例子:

flag = True

if flag:
    t = torch.ones((2,5))

    list_of_tensors = torch.split(t,2,dim=1)
    for idx,t in enumerate(list_of_tensors):
        print("第{}个张量:{},shape is {}".format(idx+1,t,t.shape))


结果:
pytorch(二)--张量的操作汇总_第7张图片
三,张量的索引
第一种方法:
pytorch(二)--张量的操作汇总_第8张图片
举个例子:


flag = True
if flag:
    t = torch.randint(0,9,size=(3,3))
    idx = torch.tensor([0,2],dtype=torch.long) #一般这里只能用long类型
    t_select = torch.index_select(t,dim=0,index=idx)
    print("t:\n{}\nt_select:\n{}".format(t,t_select))

结果:
pytorch(二)--张量的操作汇总_第9张图片

第二种方法:
pytorch(二)--张量的操作汇总_第10张图片举个例子:

flag = True
if flag:
    t = torch.randint(0,9,size=(3,3))
    mask = t.ge(5)  #指生成数组中大于等于5的元素为True
    t_select = torch.masked_select(t,mask)
    print("t:\n{}\nmask:\n{}\nt_select:\n{}".format(t,mask,t_select))

结果为:
pytorch(二)--张量的操作汇总_第11张图片
四,张量变换
第一种方法:
pytorch(二)--张量的操作汇总_第12张图片这个太简单就不举例了

第二,三 种方法:
pytorch(二)--张量的操作汇总_第13张图片
第四,五种方法:
pytorch(二)--张量的操作汇总_第14张图片举个例子:

flag = True
if flag:
    t = torch.rand((1,2,3,1))
    t_sq = torch.squeeze(t)
    t_0 = torch.squeeze(t,dim=0)
    t_1 = torch.squeeze(t,dim=1)
    print(t.shape)
    print(t_sq.shape)
    print(t_0.shape)
    print(t_1.shape)

结果为:
pytorch(二)--张量的操作汇总_第15张图片
五,有点特殊的张量数学运算
pytorch(二)--张量的操作汇总_第16张图片

你可能感兴趣的:(Pytorch,pytorch,深度学习)