Pytorch学习笔记(三)

(6)在Pytorch中实现自己定义的层:
在Pytorch中实现自己定义的层需要继承torch.autograd.Function类,然后实现其中的forward和backward方法,代码如下所示:

# -*- coding:utf-8 -*-
import torch
from torch.autograd import Variable

class MyReLU(torch.autograd.Function):
    def forward(self, input):
        self.save_for_backward(input)
        return input.clamp(min=0)

    def backward(self, grad_output):
        input, = self.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input


dtype = torch.FloatTensor
# dtype = torch.cuda.FloatTensor # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold input and outputs, and wrap them in Variables.
x = Variable(torch.randn(N, D_in).type(dtype), requires_grad=False)
y = Variable(torch.randn(N, D_out).type(dtype), requires_grad=False)

# Create random Tensors for weights, and wrap them in Variables.
w1 = Variable(torch.randn(D_in, H).type(dtype), requires_grad=True)
w2 = Variable(torch.randn(H, D_out).type(dtype), requires_grad=True)

learning_rate = 1e-6
for t in range(500):
    # Construct an instance of our MyReLU class to use in our network
    relu = MyReLU()

    # Forward pass: compute predicted y using operations on Variables; we compute
    # ReLU using our custom autograd operation.
    y_pred = relu(x.mm(w1)).mm(w2)

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.data[0])

    # Use autograd to compute the backward pass.
    loss.backward()

    # Update weights using gradient descent
    w1.data -= learning_rate * w1.grad.data
    w2.data -= learning_rate * w2.grad.data


    w1.grad.data.zero_()
    w2.grad.data.zero_()

运行结果如下:

0 27789192.0
1 21618462.0
2 17495414.0
3 13794485.0
4 10308335.0
5 7368669.5
6 5121090.5
7 3554541.75
8 2508995.25
9 1826433.625
10 1376494.0
11 1072881.0
12 860489.9375
13 706387.0625
14 590104.375
15 499442.3125
...

(7)动态图与静态图的对比:
前面提到了Pytorch是属于动态图计算的框架,而Tensorflow是属于静态图计算的框架(其实最新发布的Tensorflow Fold也增加了对动态图的支持),下面介绍一下静态图和动态图的概念。
首先,我们要搞清楚深度学习框架所谓的“动态”和“静态”究竟是按照什么标准划分的。
在静态框架使用的是静态声明 (static declaration)策略,计算图的声明和执行是分开的,换成比喻的说法就是:建筑设计师画建筑设计图(声明)和施工队建造房子(执行)是分开进行的。画设计图的时候施工队的建筑工人、材料和机器都还没动,这也就是我们说的静态。这个整个声明和执行的过程中涉及到两个图,这里我们分别给它们一个名字,声明阶段构建的图叫虚拟计算图,在这个过程中框架需要将用户的代码转化为可以一份详细的计算图,这份计算图一般会包含计算执行顺序和内存空间分配的策略,这些策略的制定一般是这个过程中最消耗时间的部分;执行阶段构建的图叫实体计算图,这个过程包括为参数和中间结果实际分配内存空间,并按照当前需求进行计算等,数据就在这张实体计算图中计算和传递。
而动态框架则不同,使用的是动态声明(dynamic declaration)策略,声明和执行一起进行的。比喻一下就是设计师和施工队是一起工作的,设计师看邮件的第一句如“要有一个二十平方米的卧室”,马上画出这个卧室的设计图交给施工队建造,然后再去看第二句。这样虚拟计算图和实体计算图的构建就是同步进行的了。因为可以实时地计划,动态框架可以根据实时需求构建对应的计算图,在灵活性上,动态框架会更胜一筹。
个人觉得总结一下就是静态图对于网络的设计可能更加的复杂一些,但是静态图的效率更高,也更适合进行分布式部署,所以比较适合工程使用;动态图的优势在于快速实现网络结构,效率上可能不如静态图,比较适合research领域使用。
下面看一段体现动态图优势的代码:

# -*- coding:utf-8 -*-
# ControlFlow + Weight Sharing
import random
import torch
from torch.autograd import Variable
import torch.nn as nn

class DynamicNet(nn.Module):
    def __init__(self, D_in, H, D_out):
        super(DynamicNet, self).__init__()
        self.input_linear = nn.Linear(D_in, H)
        self.middle_linear = nn.Linear(H, H)
        self.output_linear = nn.Linear(H, D_out)

    def forward(self, x):
        h_relu = self.input_linear(x).clamp(min=0)
        for _ in range(random.randint(0, 3)):
            h_relu = self.middle_linear(h_relu).clamp(min=0)
        y_pred = self.output_linear(h_relu)
        return y_pred

N, D_in, H, D_out = 64, 1000, 100, 10
x = Variable(torch.randn(N, D_in))
y = Variable(torch.randn(N, D_out), requires_grad=False)

model = DynamicNet(D_in, H, D_out)
criterion = torch.nn.MSELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
for t in range(500):
    y_pred = model(x)
    loss = criterion(y_pred, y)
    print(t, loss.data[0])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

运行结果如下:

0 673.64990234375
1 664.180419921875
2 668.1372680664062
3 669.5772094726562
4 666.9349975585938
5 663.29736328125
6 604.0421142578125
7 655.1654052734375
8 581.149658203125
9 566.9959716796875
10 659.7771606445312
...

你可能感兴趣的:(Pytorch)