Dropout判断可以抵抗过拟合的方法

Dropout判断可以抵抗过拟合的方法_第1张图片
Dropout判断可以抵抗过拟合的方法_第2张图片
Dropout判断可以抵抗过拟合的方法_第3张图片

Dropout

LR=0.5
model = Net()
mse_loss = nn.CrossEntropyLoss()
#定义优化器,设置正则化L2
optimizer=optim.SGD(model.parameters(),LR,weight_decay=0.001)

def train(): #调用一次,训练一个周期
    model.train()  # dropout 起作用
    for i,data in enumerate(train_loader):
        #获得一个批次的数据和标签
        inputs, labels = data
        #(64,10)
        out = model(inputs)
#         #把数据标签变成独热编码
#         #[64]-->[64,1]
#         labels = labels.reshape(-1,1)
#         #tensor.scatter(dim,index,src)
#         #dim:对哪个维度进行独热编码
#         #index:要将src中对用的值放到tensor中的哪个位置
#         #src:插入index的数值
#         #将label转为one_hot编码1-->[1,0,0,0,0,0,0,0,0,0]
#         one_hot = torch.zeros(inputs.shape[0],10).scatter(1,labels,1) #将1放到labels中的哪个位置
        #计算loss
        loss = mse_loss(out,labels)
        #梯度清零
        optimizer.zero_grad()
        #计算梯度
        loss.backward()
        #更新权值
        optimizer.step()
def test():
    #测试集准确率
    model.eval() #dropout不工作
    correct=0
    for i,data in enumerate(test_loader):
        inputs,labels = data
        out = model(inputs)
        _,predicted = torch.max(out,1)
        correct += (predicted==labels).sum()
    print("Test acc:{0}".format(correct.item()/len(test_dataset)))
    print(correct.item())
    print(len(test_dataset))
    #训练集准确率
    correct1=0
    for i,data in enumerate(train_loader):
        inputs,labels = data
        out = model(inputs)
        _,predicted = torch.max(out,1)
        correct1 += (predicted==labels).sum()
    print("Train acc:{0}".format(correct1.item()/len(train_dataset)))

Dropout:对比训练集和测试集的准确率,取差值;对比没有使用Dropout=0的差值。可以说明其可抵抗过拟合(差值越小越好)
正则化:可以抵抗过拟合。

你可能感兴趣的:(PyTorch)