pytorch中CNN网络的实现

1.加载数据集

这里我们加载的是 MNIST 数据集。数据文件 mnist.npz 已事先下载到本地的 ./data 目录，直接用 numpy 读取即可。为了加快训练，只取前 20000 个训练样本和前 1000 个测试样本。

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.utils.data as Data

# import keras
# plt.ion()

# Load a local copy of MNIST and keep a subset so training stays quick:
# the first 20000 training samples and the first 1000 test samples.
# Arrays are looked up by key instead of by position in `data.files`, whose
# order is only an artifact of how the archive was written.
# NOTE(review): assumes the Keras-style mnist.npz with keys
# 'x_train', 'y_train', 'x_test', 'y_test' -- confirm for your copy.
# The `with` block closes the underlying .npz file once the arrays are read.
with np.load('./data/mnist.npz') as data:
    X_test = data['x_test'][:1000]
    X_train = data['x_train'][:20000]
    y_train = data['y_train'][:20000]
    y_test = data['y_test'][:1000]

# to Tensor.
# Images become float32 tensors with an explicit channel axis inserted at
# dim 1 (presumably (N, 28, 28) -> (N, 1, 28, 28), the layout Conv2d expects).
# Labels become int64 class indices, as CrossEntropyLoss requires.
X_train = torch.Tensor(X_train).unsqueeze(dim=1)
y_train = torch.Tensor(y_train).long()
X_test = torch.Tensor(X_test).unsqueeze(dim=1)
y_test = torch.Tensor(y_test).long()

2.模型的搭建

# Training hyper-parameters.
BATCH_SIZE = 64       # samples per mini-batch
LEARNING_RATE = 0.02  # step size handed to the optimizer
EPOCH = 2             # number of passes over the training subset

# Pair each training image with its label, then serve shuffled mini-batches
# of BATCH_SIZE pairs from that dataset.
torch_dataset = Data.TensorDataset(X_train, y_train)
loader = Data.DataLoader(torch_dataset, batch_size=BATCH_SIZE, shuffle=True)


class CNN(torch.nn.Module):
    """Two-stage convolutional classifier for single-channel images.

    Each stage is Conv2d(3x3, stride 1, padding 1) -> ReLU -> MaxPool2d(2),
    which halves the spatial size: 28x28 -> 14x14 -> 7x7. The flattened
    32*7*7 feature map is then projected to 10 class scores by a linear layer.

    ``forward`` returns raw logits (no softmax), which is what
    ``torch.nn.CrossEntropyLoss`` expects. ``self.softmax`` can be applied to
    those logits to obtain per-class probabilities.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Sequential(
            torch.nn.Conv2d(
                in_channels=1,  # single-channel (grayscale) input
                out_channels=16,
                kernel_size=3,
                stride=1,
                # Keeps height/width unchanged, like padding='same' in Keras.
                # For stride 1 and an odd kernel: padding = (kernel_size - 1) // 2.
                padding=1,
            ),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2),  # shape = [16, 14, 14] per image
        )
        self.conv2 = torch.nn.Sequential(
            torch.nn.Conv2d(
                in_channels=16,
                out_channels=32,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2),  # shape = [32, 7, 7] per image
        )
        # dim=1 is the class axis of the (batch, 10) logits. Passing it
        # explicitly avoids PyTorch's deprecated implicit-dim behaviour.
        self.softmax = torch.nn.Softmax(dim=1)
        self.predict = torch.nn.Linear(32 * 7 * 7, 10)

    def forward(self, x):
        """Return class logits of shape (batch, 10) for the image batch ``x``.

        The 32*7*7 input size of ``self.predict`` means ``x`` must be
        (batch, 1, H, W) with H, W pooling down to 7x7 (e.g. 28x28).
        """
        x = self.conv1(x)
        x = self.conv2(x)
        x = torch.flatten(x, 1)  # (batch, 32, 7, 7) -> (batch, 32*7*7)
        return self.predict(x)

# Instantiate the network and print its layer summary.
net = CNN()
print(net)

# The model emits raw logits, so cross-entropy is applied directly to them.
# NOTE: Adam's betas are set explicitly here; PyTorch's default is (0.9, 0.999).
loss_func = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    net.parameters(),
    lr=LEARNING_RATE,
    betas=(0.8, 0.8),
)

3.模型训练以及可视化

# Per-epoch history consumed by the plotting code: loss and accuracy on the
# first 1000 training samples and on the first 1000 test samples.
global_loss_train = []
global_loss_test = []
global_acc_train = []
global_acc_test = []


def _evaluate(model, x, y):
    """Return ``(loss, accuracy)`` of ``model`` on inputs ``x`` with labels ``y``.

    ``loss`` is ``loss_func`` on the logits, as a Python float. ``accuracy``
    is the fraction of samples whose arg-max logit equals the label. The pass
    runs under ``torch.no_grad()``, so no autograd graph is built.
    """
    with torch.no_grad():
        logits = model(x)
        loss = loss_func(logits, y).item()
        acc = (torch.argmax(logits, 1) == y).sum().item() / len(y)
    return loss, acc


for epoch in range(EPOCH):
    print('----------------------epoch------------------', epoch)
    for step, (batch_x, batch_y) in enumerate(loader):
        # Standard update: forward, loss, clear old grads, backprop, step.
        prediction = net(batch_x)
        loss = loss_func(prediction, batch_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Every 36 mini-batches, report the current batch loss and test accuracy.
        if step % 36 == 0:
            print('loss:{:.2f}'.format(loss.item()))
            _, test_acc = _evaluate(net, X_test[:1000], y_test[:1000])
            print('test-acc:{:.2f}'.format(test_acc))

    # End of epoch: record train/test loss and accuracy for the plots.
    print('----------------epoch end---------------------', epoch)
    loss_train, train_acc = _evaluate(net, X_train[:1000], y_train[:1000])
    global_loss_train.append(loss_train)
    global_acc_train.append(train_acc)
    print('|loss_train|', loss_train, '|train_acc|', train_acc)
    print(global_loss_train)
    loss_test, acc_test = _evaluate(net, X_test[:1000], y_test[:1000])
    global_loss_test.append(loss_test)
    global_acc_test.append(acc_test)
    print('|loss_test|', loss_test, '|test_acc|', acc_test)


epochs = range(1, len(global_acc_test) + 1)

# One figure per metric: training values as blue dots, test values as a red
# line, both indexed by epoch number (starting at 1).
metric_panels = (
    ('acc', global_acc_train, global_acc_test),
    ('loss', global_loss_train, global_loss_test),
)
for panel_idx, (panel_title, train_curve, test_curve) in enumerate(metric_panels):
    if panel_idx > 0:
        plt.figure()  # the first panel draws on the implicitly created figure
    plt.plot(epochs, train_curve, 'bo', label='Train')
    plt.plot(epochs, test_curve, 'r', label='test')
    plt.title(panel_title)
    plt.legend()
plt.show()

最后绘制了两张图，分别是每个 epoch 结束时训练集（前 1000 个样本）和测试集上的准确率与损失。
pytorch中CNN网络的实现_第1张图片
pytorch中CNN网络的实现_第2张图片

你可能感兴趣的:(机器学习)