一、Pytouch 计算图、张量、变量

环境 conda 3.x, Python3.7,torch 1.3.0+cpu

一、计算图

1、计算图示意图

https://www.cnblogs.com/catnip/p/8760780.html

import torch
from torch.autograd import Variable

if __name__ == '__main__':
    x = Variable(torch.randn(1,10))
    prev_h = Variable(torch.randn(1,20))
    w_h = Variable(torch.randn(20,20))
    w_x = Variable(torch.randn(20,10))
    print(x)
    print(x.t())#转置矩阵
    # torch.mul(a, b)    是矩阵a和b对应位相乘,a和b的维度必须相等
    # torch.mm(a, b)是矩阵a和b矩阵相乘,
    #
    i2h = torch.mm(w_x,x.t())
    h2h = torch.mm(w_h,prev_h.t())
    next_h = i2h + h2h
    next_h = next_h.tanh()
    print(i2h)

(1)、这个图里有两种节点:Variable节点和Function节点

Variable记录运算数据,Function记录运算操作。

其中Variable节点又可以分为叶节点和非叶节点两类。叶节点由用户直接创建产生,而非叶节点则由Variable节点之间的运算操作产生,在图的代码中,x、prev_h、W_h、W_x属于叶节点,i2h、h2h、next_h属于非叶节点。 

(2)、在这个图上,节点之间的关系是很明确的:

Variable非叶节点指向产生它的Function,因为产生某个Variable的Function只可能有一个,因此一个Variable只指向一个Function。

Function的指向则是可以一对多的,因为一个运算函数往往可以接受大量的参数。

Function指向两种节点,当Function接受一个叶节点的Variable输入时,Function需指向此Variable,当Function接受一个非叶节点Variable输入时,Function需指向此Variable所指向的那个Function。

二、张量 tensor

tensor的数据类型:

  • 32 位浮点型 torch.Float Tensor

  • 64 位浮点型 torch.DoubleTensor

  • 16 位整型 torch.Shor tTensor

  • 32 位 整型 torch.lntTensor 

  • 64 位整型 torch.LongTensor, torch.Tensor

  • 默认的是 torch.FloatTensor 数据类型。

# encoding: utf-8

import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.autograd import Variable


def set_tensor():
    """1、定义一个tensor张量"""
    list = [[2, 3],[4, 8],[7, 9]]
    res = torch.Tensor(list)
    print('a is: {}'.format(a))  # 一种输出方式:{}的内容为format括号里的内容
    print('a size is {}'.format(a.size()))  # 输出矩阵的维度
    return res

def other_tensor():
    """2、其他常用方式定义tensor矩阵"""
    list = [[2, 3],[4, 8],[7, 9]]
    x = torch.LongTensor(list)  # 64位整型的tensor
    x = torch.zeros((3, 2))  # 定义(3,2)的零矩阵
    x = torch.randn((3, 2))  # 定义(3,2)的正态分布随机矩阵
    x = torch.ones(2, 2)  # 构建数值均为1的矩阵
    x = torch.eye(3, 3)  # 构建对角线为一的矩阵
    print(x)
    return x

def matrix():
    """3、矩阵的数据类型转化"""
    a = torch.Tensor([[2, 3], [4, 8], [7, 9]])
    a = a.float()  # 转为FloatTensor
    a = a.double()  # 转为DoubleTensor
    a = a.short()  # 转为ShortTensor
    a = a.int()  # 转为IntTensor
    a = a.long()  # 转为LongTensor
    return a

def trans_np_tensor():
    """4、tensor与numpy的转化(某些包(如matplotlib)只识别numpy的矩阵)"""
    a = torch.Tensor([[2, 3], [4, 8], [7, 9]])  # a为tensor型
    b = np.array(a)  # b为numpy型
    c = torch.from_numpy(b)  # c为tensor型
    return c

if __name__ == '__main__':
    a = set_tensor()
    b = other_tensor()
    c = matrix()
    d = trans_np_tensor()

一、Pytouch 计算图、张量、变量_第1张图片

三、Variable 变量

  • Variable就是变量, Variable 提供了自动求导的功能

  • Variable 和 Tensor 本质上没有区别,不过 Variable 会被放入一个计算图中,然后进行前向传播,反向传播,自动求导(神经网络基本的需求)。

# encoding: utf-8

import torch
import numpy as np
import matplotlib.pyplot as plt
import pytouch
from torch.autograd import Variable


def automatic_derivation():
    """自动求导 """
    # Create Varìable
    # requires_grad = True ,这个参数表示是否对这个变量求梯度,默认的是 Fa!se ,
    # 也就是不对这个变量求梯度,这里我们希望得到这些变量的梯度,所以需要传入这个参数。
    X = Variable(torch.Tensor([1]), requires_grad=True)
    W = Variable(torch.Tensor([2]), requires_grad=True)
    b = Variable(torch.Tensor([3]), requires_grad=True)
    Y = W * X + b  # y=2*x + 3
    Y.backward()  # y函数反向传递、自动求导
    # x、w、b在x = 1、w = 2、b=3下的对x w b 的导数
    print(X.grad)  # x.grad = 2
    print(W.grad)  # w.grad = 1
    print(b.grad)  # b.grad =1
    # result
    # tensor([2.])
    # tensor([1.])
    # tensor([1.])

def matrix_automatic_derivation():
    """对矩阵求导"""
    # 相当于给出了一个三维向量去做运算,这时候得到的结果 ν 就是一个向量
    # 这里对这个向量求导就不能直接写成 y.backward(),这样程序是会报错的。
    # 这个时候需要传入参数声明,比如 y.backward(torch.FloatTensor 1,1, 1) )).
    # 这样得到的结果就是它们每个分量的梯度,
    # 或者可以传入 y.backward(torch.FloatTensor( [1, 0.1 , 0. 01] )) ,
    # 这样得到的梯度就是它们原本的梯度分别乘上 1 , 0.1 和 0.01。
    x = torch.randn(3)
    x = Variable(x, requires_grad=True)
    print(x)
    y = x * 2
    print(y)
    y.backward(torch.tensor([1, 0.1, 0.01]))
    print(x.grad)

if __name__ == '__main__':
    automatic_derivation()
    matrix_automatic_derivation()

 

你可能感兴趣的:(Pytorch)