PyTorch 深度学习处理多维特征的输入

import numpy as np
import torch
import matplotlib.pyplot as plt

# prepare dataset
xy = np.loadtxt('diabetes.csv', delimiter=',', dtype=np.float32)
x_data = torch.from_numpy(xy[:, :-1])  # 第一个‘:’是指读取所有行,第二个‘:’是指从第一列开始,最后一列不要
print("input data.shape", x_data.shape)
y_data = torch.from_numpy(xy[:, [-1]])  # [-1] 最后得到的是个矩阵


# print(x_data.shape)
# design model using class


class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 2)
        self.linear4 = torch.nn.Linear(2, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))  # y hat
        x = self.sigmoid(self.linear4(x))  # y hat
        return x


model = Model()

# construct loss and optimizer
# criterion = torch.nn.BCELoss(size_average = True)
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

loss_list = []
acc_list = []
# training cycle forward, backward, update
for epoch in range(100):
    y_pred = model(x_data)  # 前向传播
    loss = criterion(y_pred, y_data)  # 计算损失函数

    optimizer.zero_grad()  # 梯度清零
    loss.backward()  # 反向传播
    optimizer.step()  # 更新参数

    y_pred_label = torch.where(y_pred >= 0.5, torch.tensor([1.0]), torch.tensor([0.0]))  # 将预测值大于0.5的置为1,小于0.5的置为0
    acc = torch.eq(y_pred_label, y_data).sum().item() / y_data.size(0)  # 计算准确率
    print("loss = ", loss.item(), "acc = ", acc)
    loss_list.append(loss.item())  # 将损失函数值加入列表中
    acc_list.append(acc)  # 将准确率加入列表中

# 绘制损失函数和准确率
epoch = np.arange(0, 100, 1)
plt.scatter(epoch, loss_list, c='r')  # 绘制损失函数,红色
plt.scatter(epoch, acc_list, c='b')  # 绘制准确率,蓝色
plt.pause(0.0001)
plt.show()

PyTorch 深度学习处理多维特征的输入_第1张图片

diabetes.csv.gz数据集下载:

链接:https://pan.baidu.com/s/1vZ27gKp8Pl-qICn_p2PaSw
提取码:cxe4

你可能感兴趣的:(深度学习,pytorch,人工智能)