6. Logistic Regression in PyTorch (Binary Classification with Multi-Dimensional Input)

1 Preparing the Data

import torch
import matplotlib.pyplot as plt
import numpy as np
xy = np.loadtxt('./资料/data/diabetes.csv.gz', delimiter=',', dtype=np.float32)

# The first ':' selects every row; ':-1' selects every column except the last (the 8 features)
x_data = torch.from_numpy(xy[:, :-1])

# Indexing with [-1] (a list) keeps the result 2-D, so y_data has shape (N, 1), which BCELoss expects
y_data = torch.from_numpy(xy[:, [-1]])
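
As a quick sanity check, the shapes should line up with the model below: (N, 8) features and (N, 1) labels. A minimal sketch (N depends on your copy of diabetes.csv.gz):

# Verify that the tensors match the 8-feature input the first Linear layer expects
print(x_data.shape)  # torch.Size([N, 8])
print(y_data.shape)  # torch.Size([N, 1])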

2 Building the Model

class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6) # the input x has 8 features, so the first layer takes 8 inputs
        self.linear2 = torch.nn.Linear(6, 4) # 6 inputs, 4 outputs
        self.linear3 = torch.nn.Linear(4, 1) # 4 inputs, 1 output
        self.sigmoid = torch.nn.Sigmoid() # a single Sigmoid module, reused after every layer; the last one maps the output into (0, 1)
 
    def forward(self, x):
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x)) # y hat
        return x
 
 
model = Model()
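
Before training, it helps to confirm the wiring with a dummy batch; a minimal sketch (the random input is illustrative only):

# Push 2 random samples through the untrained model; expect a (2, 1) tensor of probabilities
with torch.no_grad():
    print(model(torch.randn(2, 8)).shape)  # torch.Size([2, 1]), values in (0, 1) after the final sigmoid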

3 Training the Model

# construct loss and optimizer
# criterion = torch.nn.BCELoss(size_average=True)  # size_average is deprecated; use reduction='mean' instead
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
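
A common, numerically more stable alternative (a sketch, not what this code does) is to drop the final sigmoid in forward() and use torch.nn.BCEWithLogitsLoss, which fuses the sigmoid into the loss:

# Hypothetical variant: forward() would return self.linear3(x) directly (raw logits)
criterion_logits = torch.nn.BCEWithLogitsLoss(reduction='mean')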
 
epoch_list = []
loss_list = []
# training cycle forward, backward, update
for epoch in range(100):
    # 1. Forward
    y_pred = model(x_data)
    loss = criterion(y_pred, y_data)
    print(epoch, loss.item())
    epoch_list.append(epoch)
    loss_list.append(loss.item())
    
    # 2. Backward
    optimizer.zero_grad()
    loss.backward()
    
    # 3. Update
    optimizer.step()
0 0.6853322982788086
1 0.6812020540237427
2 0.6775144338607788
3 0.6742204427719116
4 0.6712767481803894
...
97 0.6455734372138977
98 0.6455726623535156
99 0.6455719470977783
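
The script collects epoch_list and loss_list and imports matplotlib, but never plots them. A minimal sketch to draw the loss curve and estimate training-set accuracy (the 0.5 threshold is an assumption, not part of the original):

# Plot the loss values collected during training
plt.plot(epoch_list, loss_list)
plt.xlabel('epoch')
plt.ylabel('BCE loss')
plt.show()

# Rough training-set accuracy: threshold predicted probabilities at 0.5
with torch.no_grad():
    y_pred = model(x_data)
    acc = ((y_pred >= 0.5).float() == y_data).float().mean().item()
    print('train accuracy:', acc)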
