Most readers are no doubt familiar with the FizzBuzz problem.
Given an integer n, print each number from start to start + n according to the following rules:

if the number is divisible by 3, print fizz;
if the number is divisible by 5, print buzz;
if the number is divisible by both 3 and 5, print fizzbuzz;
otherwise, print the number itself.

Example: start = 1, n = 14 produces ["1", "2", "fizz", "4", "buzz", "fizz", "7", "8", "fizz", "buzz", "11", "fizz", "13", "14", "fizzbuzz"].
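As a baseline, the rules above can be implemented directly. The sketch below is a plain, non-neural reference implementation (the function name fizzbuzz is ours, not part of the original code):

```python
def fizzbuzz(start, n):
    # Return the FizzBuzz strings for start, start + 1, ..., start + n.
    out = []
    for i in range(start, start + n + 1):
        if i % 15 == 0:
            out.append('fizzbuzz')
        elif i % 5 == 0:
            out.append('buzz')
        elif i % 3 == 0:
            out.append('fizz')
        else:
            out.append(str(i))
    return out

print(fizzbuzz(1, 14))
```

Running it with start = 1, n = 14 reproduces the sample output above. The neural-network version must learn exactly this mapping from data.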
Now the task is to solve FizzBuzz with a neural network.
FizzBuzz is essentially a four-way classification problem: given a number, we must assign it to one of four classes (the number itself, Fizz, Buzz, or FizzBuzz). We can build a neural network whose input, hidden, and output layers are all fully connected, and use it to perform this classification.
Below we use PyTorch to build a network with two hidden layers. The input number is converted to binary, and each bit (0 or 1) becomes one input to the network. The input layer has 20 units, so the network can classify numbers in the range [1, 2^20); the output layer has 4 units, one per class; ReLU is used as the activation function.
import torch
import torch.nn as nn


class net(nn.Module):
    def __init__(self):
        super(net, self).__init__()
        self.fc1 = nn.Sequential(
            nn.Linear(20, 256),
            nn.ReLU(inplace=True)
        )
        self.fc2 = nn.Sequential(
            nn.Linear(256, 128),
            nn.ReLU(inplace=True)
        )
        self.fc3 = nn.Linear(128, 4)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x


def fizzbuzz_encode(i):
    # Map a number to its class label:
    # 0 = the number itself, 1 = fizz, 2 = buzz, 3 = fizzbuzz.
    if i % 15 == 0:
        return 3
    elif i % 5 == 0:
        return 2
    elif i % 3 == 0:
        return 1
    else:
        return 0


def fizzbuzz_decode(i, pred):
    # Map a predicted class label back to the output string.
    return [str(i), 'fizz', 'buzz', 'fizzbuzz'][pred]


def bin_encode(i, bit):
    # Encode i as a fixed-width binary vector of `bit` 0/1 entries
    # (the original version hard-coded the width to 20 and ignored `bit`).
    return [int(c) for c in '{:0>{}s}'.format(bin(i)[2:], bit)]
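To make the encoding concrete, here is a quick sanity check; the snippet contains self-contained copies of the two helpers above so it runs on its own. The number 15 becomes a 20-bit 0/1 vector, and its label is class 3 (fizzbuzz):

```python
def fizzbuzz_encode(i):
    # Class label: 0 = the number itself, 1 = fizz, 2 = buzz, 3 = fizzbuzz.
    if i % 15 == 0:
        return 3
    elif i % 5 == 0:
        return 2
    elif i % 3 == 0:
        return 1
    return 0

def bin_encode(i, bit):
    # Fixed-width binary vector of `bit` 0/1 entries, most significant bit first.
    return [int(c) for c in format(i, '0{}b'.format(bit))]

print(bin_encode(15, 20))   # 16 zeros followed by 1, 1, 1, 1
print(fizzbuzz_encode(15))  # 3, i.e. 'fizzbuzz'
```

The network never sees the decimal value directly; it only sees these 20 bits, which is what makes the task learnable by a small fully connected model.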
Next we train the model.
To build the training and validation sets, I first construct a list nums containing all integers in the range [1, 2^20). Using sklearn.model_selection.train_test_split() with train_size=0.8, the data is split 4:1 into a training set and a validation set, giving train_X and eval_X, from which train_y and eval_y are then constructed.
Since this is a multi-class classification task, the loss function is cross-entropy, torch.nn.CrossEntropyLoss(). The optimizer is adaptive moment estimation (Adam), torch.optim.Adam(), with learning rate lr=5e-4 and momentum terms betas=(0.9, 0.999). We train for 100 epochs with a batch size of 128.
The training code follows.
import os
import numpy as np
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
from fizzbuzz import fizzbuzz_encode, bin_encode
from model import net

if not os.path.exists('./model/'):
    os.mkdir('./model/')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('[Fizz-Buzz is being trained using the {}]'.format('GPU' if torch.cuda.is_available() else 'CPU'))

nums = [i for i in range(1, 2 ** 20)]
train_X, eval_X = train_test_split(nums, train_size=0.8)

print('Loading the training set...')
train_size = len(train_X)
train_y = torch.LongTensor([fizzbuzz_encode(i) for i in train_X])
train_X = torch.Tensor(np.array([bin_encode(i, 20) for i in train_X]))
print('Over. The size of the training set is {}.'.format(train_size))

print('Loading the eval set...')
eval_size = len(eval_X)
eval_y = torch.LongTensor([fizzbuzz_encode(i) for i in eval_X])
eval_X = torch.Tensor(np.array([bin_encode(i, 20) for i in eval_X]))
print('Over. The size of the eval set is {}.'.format(eval_size))

model = net().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=5e-4, betas=(0.9, 0.999), eps=1e-8)
EPOCH = 100
BATCH_SIZE = 128

print('Training...')
for epoch in range(1, EPOCH + 1):
    # Train
    train_loss = 0
    train_acc = 0
    train_cnt = 0
    model = model.train()
    for step in range(0, len(train_X), BATCH_SIZE):
        batch_X = train_X[step:step + BATCH_SIZE].to(device)
        batch_y = train_y[step:step + BATCH_SIZE].to(device)
        out = model(batch_X)
        loss = criterion(out, batch_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        _, pred = out.max(1)
        num_correct = (pred == batch_y).sum().item()
        # Divide by the actual batch size; the last batch may be smaller than BATCH_SIZE.
        acc = num_correct / len(batch_y)
        train_acc += acc
        train_cnt += 1
    # Eval
    eval_loss = 0
    eval_acc = 0
    eval_cnt = 0
    model = model.eval()
    with torch.no_grad():  # no gradients needed during evaluation
        for step in range(0, len(eval_X), BATCH_SIZE):
            batch_X = eval_X[step:step + BATCH_SIZE].to(device)
            batch_y = eval_y[step:step + BATCH_SIZE].to(device)
            out = model(batch_X)
            loss = criterion(out, batch_y)
            eval_loss += loss.item()
            _, pred = out.max(1)
            num_correct = (pred == batch_y).sum().item()
            acc = num_correct / len(batch_y)
            eval_acc += acc
            eval_cnt += 1
    print('[{:3d}/100] Train Loss: {:11.9f} | Train Accuracy: {:6.4f} | Eval Loss: {:11.9f} | Eval Accuracy: {:6.4f}'
          .format(epoch, train_loss / train_cnt, train_acc / train_cnt, eval_loss / eval_cnt, eval_acc / eval_cnt))
    # Save model
    torch.save(model.state_dict(), './model/fizzbuzz_epoch{}.pth'.format(epoch))
print('Training completed.')
The training output is shown below:
[Fizz-Buzz is being trained using the GPU]
Loading the training set...
Over. The size of the training set is 838860.
Loading the eval set...
Over. The size of the eval set is 209715.
Training...
[ 1/100] Train Loss: 1.102775070 | Train Accuracy: 0.5531 | Eval Loss: 0.928931041 | Eval Accuracy: 0.6458
[ 2/100] Train Loss: 0.647325612 | Train Accuracy: 0.7525 | Eval Loss: 0.466397490 | Eval Accuracy: 0.8216
[ 3/100] Train Loss: 0.381442493 | Train Accuracy: 0.8596 | Eval Loss: 0.324849255 | Eval Accuracy: 0.8827
[ 4/100] Train Loss: 0.288911639 | Train Accuracy: 0.8829 | Eval Loss: 0.251872926 | Eval Accuracy: 0.8999
[ 5/100] Train Loss: 0.205670430 | Train Accuracy: 0.9244 | Eval Loss: 0.177700364 | Eval Accuracy: 0.9481
[ 6/100] Train Loss: 0.139342964 | Train Accuracy: 0.9565 | Eval Loss: 0.115904677 | Eval Accuracy: 0.9734
[ 7/100] Train Loss: 0.097691057 | Train Accuracy: 0.9717 | Eval Loss: 0.079528895 | Eval Accuracy: 0.9813
[ 8/100] Train Loss: 0.072108123 | Train Accuracy: 0.9791 | Eval Loss: 0.057199334 | Eval Accuracy: 0.9862
[ 9/100] Train Loss: 0.052352456 | Train Accuracy: 0.9853 | Eval Loss: 0.039544852 | Eval Accuracy: 0.9902
[ 10/100] Train Loss: 0.039294782 | Train Accuracy: 0.9893 | Eval Loss: 0.029183812 | Eval Accuracy: 0.9922
[ 11/100] Train Loss: 0.030768094 | Train Accuracy: 0.9917 | Eval Loss: 0.024201840 | Eval Accuracy: 0.9930
[ 12/100] Train Loss: 0.024932299 | Train Accuracy: 0.9934 | Eval Loss: 0.020896540 | Eval Accuracy: 0.9934
[ 13/100] Train Loss: 0.020954109 | Train Accuracy: 0.9945 | Eval Loss: 0.018185570 | Eval Accuracy: 0.9941
[ 14/100] Train Loss: 0.018094144 | Train Accuracy: 0.9952 | Eval Loss: 0.016723786 | Eval Accuracy: 0.9942
[ 15/100] Train Loss: 0.015818256 | Train Accuracy: 0.9958 | Eval Loss: 0.012055101 | Eval Accuracy: 0.9960
[ 16/100] Train Loss: 0.014259863 | Train Accuracy: 0.9961 | Eval Loss: 0.012222336 | Eval Accuracy: 0.9961
[ 17/100] Train Loss: 0.012945026 | Train Accuracy: 0.9964 | Eval Loss: 0.011727666 | Eval Accuracy: 0.9958
[ 18/100] Train Loss: 0.011773113 | Train Accuracy: 0.9966 | Eval Loss: 0.020707660 | Eval Accuracy: 0.9916
[ 19/100] Train Loss: 0.010908648 | Train Accuracy: 0.9968 | Eval Loss: 0.015347923 | Eval Accuracy: 0.9941
[ 20/100] Train Loss: 0.010148572 | Train Accuracy: 0.9969 | Eval Loss: 0.010991387 | Eval Accuracy: 0.9960
[ 21/100] Train Loss: 0.009624445 | Train Accuracy: 0.9972 | Eval Loss: 0.011484019 | Eval Accuracy: 0.9959
[ 22/100] Train Loss: 0.008491344 | Train Accuracy: 0.9974 | Eval Loss: 0.009262721 | Eval Accuracy: 0.9968
[ 23/100] Train Loss: 0.008251481 | Train Accuracy: 0.9975 | Eval Loss: 0.005522054 | Eval Accuracy: 0.9981
[ 24/100] Train Loss: 0.008246436 | Train Accuracy: 0.9975 | Eval Loss: 0.007955568 | Eval Accuracy: 0.9971
[ 25/100] Train Loss: 0.007407112 | Train Accuracy: 0.9977 | Eval Loss: 0.007498477 | Eval Accuracy: 0.9973
[ 26/100] Train Loss: 0.007269602 | Train Accuracy: 0.9978 | Eval Loss: 0.006716619 | Eval Accuracy: 0.9972
[ 27/100] Train Loss: 0.007019400 | Train Accuracy: 0.9978 | Eval Loss: 0.014589484 | Eval Accuracy: 0.9941
[ 28/100] Train Loss: 0.006642517 | Train Accuracy: 0.9979 | Eval Loss: 0.003745178 | Eval Accuracy: 0.9987
[ 29/100] Train Loss: 0.006703015 | Train Accuracy: 0.9979 | Eval Loss: 0.009273426 | Eval Accuracy: 0.9962
[ 30/100] Train Loss: 0.006389014 | Train Accuracy: 0.9980 | Eval Loss: 0.002969400 | Eval Accuracy: 0.9989
[ 31/100] Train Loss: 0.006239470 | Train Accuracy: 0.9980 | Eval Loss: 0.004633031 | Eval Accuracy: 0.9982
[ 32/100] Train Loss: 0.006201456 | Train Accuracy: 0.9980 | Eval Loss: 0.004878159 | Eval Accuracy: 0.9982
[ 33/100] Train Loss: 0.005793620 | Train Accuracy: 0.9982 | Eval Loss: 0.003159688 | Eval Accuracy: 0.9987
[ 34/100] Train Loss: 0.005702128 | Train Accuracy: 0.9981 | Eval Loss: 0.008034307 | Eval Accuracy: 0.9970
[ 35/100] Train Loss: 0.005605559 | Train Accuracy: 0.9982 | Eval Loss: 0.002718459 | Eval Accuracy: 0.9989
[ 36/100] Train Loss: 0.005403034 | Train Accuracy: 0.9983 | Eval Loss: 0.004402274 | Eval Accuracy: 0.9982
[ 37/100] Train Loss: 0.004914636 | Train Accuracy: 0.9983 | Eval Loss: 0.002392252 | Eval Accuracy: 0.9990
[ 38/100] Train Loss: 0.005260276 | Train Accuracy: 0.9983 | Eval Loss: 0.002264460 | Eval Accuracy: 0.9991
[ 39/100] Train Loss: 0.005505662 | Train Accuracy: 0.9982 | Eval Loss: 0.002417869 | Eval Accuracy: 0.9990
[ 40/100] Train Loss: 0.004814269 | Train Accuracy: 0.9984 | Eval Loss: 0.006638321 | Eval Accuracy: 0.9974
[ 41/100] Train Loss: 0.005267406 | Train Accuracy: 0.9983 | Eval Loss: 0.003875148 | Eval Accuracy: 0.9984
[ 42/100] Train Loss: 0.004856745 | Train Accuracy: 0.9984 | Eval Loss: 0.004368525 | Eval Accuracy: 0.9982
[ 43/100] Train Loss: 0.004577446 | Train Accuracy: 0.9985 | Eval Loss: 0.002692292 | Eval Accuracy: 0.9989
[ 44/100] Train Loss: 0.004956163 | Train Accuracy: 0.9983 | Eval Loss: 0.004707843 | Eval Accuracy: 0.9980
[ 45/100] Train Loss: 0.004755975 | Train Accuracy: 0.9984 | Eval Loss: 0.001884688 | Eval Accuracy: 0.9992
[ 46/100] Train Loss: 0.003918191 | Train Accuracy: 0.9987 | Eval Loss: 0.003383207 | Eval Accuracy: 0.9986
[ 47/100] Train Loss: 0.004527767 | Train Accuracy: 0.9985 | Eval Loss: 0.006468480 | Eval Accuracy: 0.9974
[ 48/100] Train Loss: 0.004405924 | Train Accuracy: 0.9986 | Eval Loss: 0.005107094 | Eval Accuracy: 0.9978
[ 49/100] Train Loss: 0.004278097 | Train Accuracy: 0.9986 | Eval Loss: 0.003901212 | Eval Accuracy: 0.9984
[ 50/100] Train Loss: 0.004224002 | Train Accuracy: 0.9986 | Eval Loss: 0.005627046 | Eval Accuracy: 0.9977
[ 51/100] Train Loss: 0.004010710 | Train Accuracy: 0.9987 | Eval Loss: 0.003938096 | Eval Accuracy: 0.9982
[ 52/100] Train Loss: 0.004115922 | Train Accuracy: 0.9986 | Eval Loss: 0.002245946 | Eval Accuracy: 0.9990
[ 53/100] Train Loss: 0.003951488 | Train Accuracy: 0.9986 | Eval Loss: 0.004402691 | Eval Accuracy: 0.9982
[ 54/100] Train Loss: 0.003650514 | Train Accuracy: 0.9987 | Eval Loss: 0.003027403 | Eval Accuracy: 0.9986
[ 55/100] Train Loss: 0.004247091 | Train Accuracy: 0.9985 | Eval Loss: 0.010029032 | Eval Accuracy: 0.9963
[ 56/100] Train Loss: 0.003386153 | Train Accuracy: 0.9989 | Eval Loss: 0.001097557 | Eval Accuracy: 0.9994
[ 57/100] Train Loss: 0.003713696 | Train Accuracy: 0.9988 | Eval Loss: 0.004928977 | Eval Accuracy: 0.9980
[ 58/100] Train Loss: 0.003816056 | Train Accuracy: 0.9987 | Eval Loss: 0.004625882 | Eval Accuracy: 0.9981
[ 59/100] Train Loss: 0.003549325 | Train Accuracy: 0.9988 | Eval Loss: 0.001582666 | Eval Accuracy: 0.9992
[ 60/100] Train Loss: 0.003565447 | Train Accuracy: 0.9988 | Eval Loss: 0.003165373 | Eval Accuracy: 0.9986
[ 61/100] Train Loss: 0.003243083 | Train Accuracy: 0.9989 | Eval Loss: 0.000977819 | Eval Accuracy: 0.9994
[ 62/100] Train Loss: 0.003249073 | Train Accuracy: 0.9989 | Eval Loss: 0.007692643 | Eval Accuracy: 0.9971
[ 63/100] Train Loss: 0.003360245 | Train Accuracy: 0.9989 | Eval Loss: 0.003523805 | Eval Accuracy: 0.9985
[ 64/100] Train Loss: 0.003304055 | Train Accuracy: 0.9989 | Eval Loss: 0.006064202 | Eval Accuracy: 0.9976
[ 65/100] Train Loss: 0.003078253 | Train Accuracy: 0.9990 | Eval Loss: 0.000786970 | Eval Accuracy: 0.9995
[ 66/100] Train Loss: 0.003281826 | Train Accuracy: 0.9989 | Eval Loss: 0.000984443 | Eval Accuracy: 0.9994
[ 67/100] Train Loss: 0.002895257 | Train Accuracy: 0.9990 | Eval Loss: 0.002222579 | Eval Accuracy: 0.9990
[ 68/100] Train Loss: 0.003152595 | Train Accuracy: 0.9989 | Eval Loss: 0.003135084 | Eval Accuracy: 0.9987
[ 69/100] Train Loss: 0.003002117 | Train Accuracy: 0.9989 | Eval Loss: 0.000713773 | Eval Accuracy: 0.9995
[ 70/100] Train Loss: 0.002931449 | Train Accuracy: 0.9990 | Eval Loss: 0.001333298 | Eval Accuracy: 0.9993
[ 71/100] Train Loss: 0.002929780 | Train Accuracy: 0.9990 | Eval Loss: 0.006103724 | Eval Accuracy: 0.9977
[ 72/100] Train Loss: 0.003025964 | Train Accuracy: 0.9989 | Eval Loss: 0.002391597 | Eval Accuracy: 0.9990
[ 73/100] Train Loss: 0.002613716 | Train Accuracy: 0.9991 | Eval Loss: 0.012521741 | Eval Accuracy: 0.9951
[ 74/100] Train Loss: 0.002859766 | Train Accuracy: 0.9990 | Eval Loss: 0.005742355 | Eval Accuracy: 0.9979
[ 75/100] Train Loss: 0.002824380 | Train Accuracy: 0.9990 | Eval Loss: 0.001468052 | Eval Accuracy: 0.9992
[ 76/100] Train Loss: 0.002657217 | Train Accuracy: 0.9991 | Eval Loss: 0.000719256 | Eval Accuracy: 0.9995
[ 77/100] Train Loss: 0.002720614 | Train Accuracy: 0.9991 | Eval Loss: 0.000864803 | Eval Accuracy: 0.9994
[ 78/100] Train Loss: 0.002478095 | Train Accuracy: 0.9991 | Eval Loss: 0.001175038 | Eval Accuracy: 0.9993
[ 79/100] Train Loss: 0.002591830 | Train Accuracy: 0.9991 | Eval Loss: 0.006251177 | Eval Accuracy: 0.9976
[ 80/100] Train Loss: 0.003015996 | Train Accuracy: 0.9989 | Eval Loss: 0.000658074 | Eval Accuracy: 0.9995
[ 81/100] Train Loss: 0.002414642 | Train Accuracy: 0.9991 | Eval Loss: 0.003294699 | Eval Accuracy: 0.9985
[ 82/100] Train Loss: 0.002576179 | Train Accuracy: 0.9991 | Eval Loss: 0.003913511 | Eval Accuracy: 0.9983
[ 83/100] Train Loss: 0.002545077 | Train Accuracy: 0.9991 | Eval Loss: 0.001200617 | Eval Accuracy: 0.9993
[ 84/100] Train Loss: 0.002407082 | Train Accuracy: 0.9992 | Eval Loss: 0.001316432 | Eval Accuracy: 0.9993
[ 85/100] Train Loss: 0.002210199 | Train Accuracy: 0.9993 | Eval Loss: 0.004904620 | Eval Accuracy: 0.9980
[ 86/100] Train Loss: 0.002598349 | Train Accuracy: 0.9991 | Eval Loss: 0.004028998 | Eval Accuracy: 0.9983
[ 87/100] Train Loss: 0.002180286 | Train Accuracy: 0.9992 | Eval Loss: 0.000772777 | Eval Accuracy: 0.9994
[ 88/100] Train Loss: 0.002247557 | Train Accuracy: 0.9992 | Eval Loss: 0.001983232 | Eval Accuracy: 0.9990
[ 89/100] Train Loss: 0.002460223 | Train Accuracy: 0.9992 | Eval Loss: 0.001086888 | Eval Accuracy: 0.9994
[ 90/100] Train Loss: 0.002043483 | Train Accuracy: 0.9993 | Eval Loss: 0.003071309 | Eval Accuracy: 0.9988
[ 91/100] Train Loss: 0.002076443 | Train Accuracy: 0.9993 | Eval Loss: 0.000947966 | Eval Accuracy: 0.9994
[ 92/100] Train Loss: 0.002625130 | Train Accuracy: 0.9991 | Eval Loss: 0.005555006 | Eval Accuracy: 0.9979
[ 93/100] Train Loss: 0.001995683 | Train Accuracy: 0.9993 | Eval Loss: 0.002265547 | Eval Accuracy: 0.9989
[ 94/100] Train Loss: 0.002201190 | Train Accuracy: 0.9992 | Eval Loss: 0.001361133 | Eval Accuracy: 0.9992
[ 95/100] Train Loss: 0.002136063 | Train Accuracy: 0.9993 | Eval Loss: 0.000965736 | Eval Accuracy: 0.9993
[ 96/100] Train Loss: 0.002415759 | Train Accuracy: 0.9992 | Eval Loss: 0.004160137 | Eval Accuracy: 0.9985
[ 97/100] Train Loss: 0.001967646 | Train Accuracy: 0.9994 | Eval Loss: 0.001117895 | Eval Accuracy: 0.9993
[ 98/100] Train Loss: 0.002086239 | Train Accuracy: 0.9993 | Eval Loss: 0.007031670 | Eval Accuracy: 0.9972
[ 99/100] Train Loss: 0.002194939 | Train Accuracy: 0.9992 | Eval Loss: 0.001470774 | Eval Accuracy: 0.9992
[100/100] Train Loss: 0.001980887 | Train Accuracy: 0.9993 | Eval Loss: 0.000567319 | Eval Accuracy: 0.9995
Training completed.
From the information printed during training we can see that, for numbers in the range [1, 2^20), the model reaches a validation accuracy of up to 0.9995, which is quite impressive overall. We take the model from epoch 100 for the tests below.
For testing, we randomly sample roughly 1/100 of the integers in [1, 2^20) as the test set (each integer is kept with probability 1/100) and load the epoch-100 model parameters.
import random
import numpy as np
import torch
import torch.nn as nn
from fizzbuzz import fizzbuzz_encode, fizzbuzz_decode, bin_encode
from model import net

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = net().to(device)
model.load_state_dict(torch.load('./model/fizzbuzz_epoch100.pth', map_location=device))

# Keep each integer with probability 1/100.
test_data = [i for i in range(1, 2 ** 20) if not random.randint(0, 99)]
real = [fizzbuzz_decode(i, fizzbuzz_encode(i)) for i in test_data]
test_X = torch.Tensor(np.array([bin_encode(i, 20) for i in test_data]))

model = model.eval()
test_X = test_X.to(device)
with torch.no_grad():  # inference only; no gradients needed
    out = model(test_X)
pred = list(out.max(1)[1].cpu().numpy())
pred = [fizzbuzz_decode(test_data[i], pred[i]) for i in range(len(pred))]

num_correct = 0
for i in range(len(real)):
    if real[i] == pred[i]:
        num_correct += 1
acc = num_correct / len(real)

print('Real results :', real)
print('Predicted results:', pred)
print('Accuracy : {:4.2f}% ({}/{})'.format(acc * 100, num_correct, len(real)))
The test output is shown below (truncated here, since the test set is large):
Real results : ['52', '79', 'buzz', '122', '172', 'fizz', '227', 'buzz', ...]
Predicted results: ['52', '79', 'buzz', '122', '172', 'fizz', '227', 'buzz', ...]
Accuracy : 100.00% (10567/10567)
As we can see, the model correctly classifies every one of the randomly selected integers in the range.