【PyTorch实战】Fully Connected Network

1. 简介

(1) 结构

【PyTorch实战】Fully Connected Network_第1张图片
简单的三层结构,第一层为输入层,第二层为隐藏层,第三层为输出层

(2) 激活函数


【PyTorch实战】Fully Connected Network_第2张图片

2.模型设计

(1) Model

import torch.nn as nn

from collections import OrderedDict

layers = OrderedDict()      # 创建顺序的dict结构

for i, n_hidden in enumerate(n_hiddens):

    layers['fc{}'.format(i+1)] = nn.Linear(current_dims, n_hidden) 

    layers['relu{}'.format(i+1)] = nn.ReLU()

    layers['drop{}'.format(i+1)] = nn.Dropout(0.2)

    current_dims = n_hidden

layers['out'] = nn.Linear(current_dims, n_class)

model = nn.Sequential(layers)    # 顺序的执行layers

print(model)

model = torch.nn.DataParallel(model, device_ids= range(ngpu)) # 数据并行

(2) Optimizer

# 采用随机梯度下降算法

# lr表示学习率

# weight_decay表示权重衰减,防止模型过拟合

# momentum加速模型的迭代,参见3

optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=0.0001, momentum=0.9) 

for epoch in range(epochs):

    model.train() # 训练模式

    if epoch in [80,200]:

        optimizer.param_groups[0]['lr'] *= 0.1 # 随着模型跌打的次数增加,学习率降低

    for batch_idx, (data, target) in enumerate(train_data):

        data, target = Variable(data), Variable(target)

        optimizer.zero_grad() # 清除上一轮的梯度,否则会进行累加

        output = model(data) # 动态图结构

        loss = F.cross_entropy(output, target) # 交叉熵

        loss.backward()

        optimizer.step()

(3) Evaluation

# 模型训练时,按照epoch的次数进行效果评估

if epoch % 10 == 0:

    model.eval() #  评估模式

    test_loss = 0

    correct = 0

    for data, target in test_date:

        data, target = Variable(data, volatile=True), Variable(target)

        output = model(data)

        test_loss += F.cross_entropy(output, target).data[0]

        pred = output.data.max(1)[1]

        correct += pred.cpu().eq(indx_target).sum()

    test_loss = test_loss / len(test_loader)

    acc = 100. * correct / len(test_loader.dataset)

3. 参考资料

1. https://github.com/aaron-xichen/pytorch-playground

2. PyTorch API

3. On the importance of initialization and momentum in deep learning

4.Fully Connected Neural Network Algorithms

5.Fully Connected Neural Network与Activation Function

你可能感兴趣的:(【PyTorch实战】Fully Connected Network)