李宏毅HW01——新冠疫情数据的预测

目的:熟悉熟悉pytorch

导入数据

!gdown --id '1kLSW_-cW2Huj7bh84YTdimGBOJaODiOS' --output covid.train.csv
!gdown --id '1iiI5qROrAhZn-o4FPqsE97bMzDEFvIdg' --output covid.test.csv
/Users/missbei/miniforge3/envs/NLP_search/lib/python3.8/site-packages/gdown/cli.py:127: FutureWarning: Option `--id` was deprecated in version 4.3.1 and will be removed in 5.0. You don't need to pass it anymore to use a file ID.
  warnings.warn(
Downloading...
From: https://drive.google.com/uc?id=1kLSW_-cW2Huj7bh84YTdimGBOJaODiOS
To: /Users/missbei/miniforge3/envs/NLP_search/Bilibili/Hung-Yi Lee/covid.train.csv
100%|██████████████████████████████████████| 2.49M/2.49M [00:00<00:00, 38.8MB/s]
/Users/missbei/miniforge3/envs/NLP_search/lib/python3.8/site-packages/gdown/cli.py:127: FutureWarning: Option `--id` was deprecated in version 4.3.1 and will be removed in 5.0. You don't need to pass it anymore to use a file ID.
  warnings.warn(
Downloading...
From: https://drive.google.com/uc?id=1iiI5qROrAhZn-o4FPqsE97bMzDEFvIdg
To: /Users/missbei/miniforge3/envs/NLP_search/Bilibili/Hung-Yi Lee/covid.test.csv
100%|████████████████████████████████████████| 993k/993k [00:00<00:00, 15.1MB/s]

import Package

import math
import numpy as np
import pandas as pd
import os
import csv
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split

Some prepared code, no need to understand

from tqdm import tqdm
def same_seed(seed):
    """Seed the random number generators used here so runs are repeatable.

    Switches cuDNN to deterministic mode with its benchmark mode turned off,
    then seeds NumPy, PyTorch and -- when CUDA is available -- every CUDA
    device with ``seed``.
    """
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

    for seed_fn in (np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

def train_valid_split(data_set, valid_ratio, seed):
    """Randomly split the provided data into a training and a validation part.

    Args:
        data_set: Array-like of samples, one sample per row.
        valid_ratio: Fraction of the rows assigned to the validation part; the
            validation size is ``int(valid_ratio * len(data_set))``.
        seed: Seed for the ``torch.Generator`` driving the random split, so the
            same seed always produces the same split.

    Returns:
        A ``(train_set, valid_set)`` tuple of NumPy arrays.
    """
    data = np.asarray(data_set)
    valid_set_size = int(valid_ratio * len(data))
    train_set_size = len(data) - valid_set_size
    generator = torch.Generator().manual_seed(seed)
    train_subset, valid_subset = random_split(
        data, [train_set_size, valid_set_size], generator=generator
    )
    # Index the array with the subset indices instead of calling
    # np.array(subset): the latter makes NumPy walk the Subset element by
    # element, and collapses an empty subset to shape (0,).
    train_idx = np.asarray(train_subset.indices, dtype=int)
    valid_idx = np.asarray(valid_subset.indices, dtype=int)
    return data[train_idx], data[valid_idx]

def predict(test_loader, model, device):
    """Run ``model`` over every batch of ``test_loader`` and gather the outputs.

    The model is put into evaluation mode, each batch is moved to ``device``,
    and the forward passes run with gradient tracking disabled.

    Returns:
        The per-batch outputs concatenated along dimension 0, as a NumPy array.
    """
    model.eval()  # evaluation mode
    batch_outputs = []
    with torch.no_grad():
        for batch in tqdm(test_loader):
            output = model(batch.to(device))
            batch_outputs.append(output.detach().cpu())
    return torch.cat(batch_outputs, dim=0).numpy()



Dataset

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-SlVuXp5c-1661580334880)(attachment:%E6%88%AA%E5%B1%8F2022-08-26%20%E4%B8%8B%E5%8D%886.32.46.png)]

  • 第0列是id
  • 第1-37列是37个州的one-hot编码
  • 38-41是COVID-like illness
  • 42-49是Behavior Indicators
  • 50-52是Mental Health Indicators
  • 53是最后检测结果,阳不阳
  • 其后各列是第2天至第5天的同一组特征(每天16个),按列依次排开
class COVID19Dataset(Dataset):
    """Dataset wrapping a feature matrix and, optionally, its targets.

    When ``y`` is omitted (``None``) the dataset is meant for prediction:
    indexing then yields only the features. Otherwise indexing yields a
    ``(features, target)`` pair. Both inputs are converted to float tensors.
    """

    def __init__(self, x, y=None):
        self.y = None if y is None else torch.FloatTensor(y)
        self.x = torch.FloatTensor(x)

    def __getitem__(self, index):
        features = self.x[index]
        if self.y is not None:
            return features, self.y[index]
        return features

    def __len__(self):
        # Number of samples equals the number of feature rows.
        return len(self.x)

Neural Network Model

class NN_Model(nn.Module):
    """Fully connected regressor: input_dim -> 16 -> 8 -> 1 with ReLU in between."""

    def __init__(self, input_dim):
        super().__init__()  # initialise the nn.Module base class
        hidden_1, hidden_2 = 16, 8
        stack = [
            nn.Linear(input_dim, hidden_1),
            nn.ReLU(),
            nn.Linear(hidden_1, hidden_2),
            nn.ReLU(),
            nn.Linear(hidden_2, 1),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        # Drop the trailing singleton dimension: (B, 1) -> (B)
        return self.layers(x).squeeze(1)

Feature Selection

def select_feat(train_data, valid_data, test_data, select_all=True, feat_idx=None):
    """Select the feature columns used for regression and split off the targets.

    The last column of ``train_data`` and ``valid_data`` is taken as the
    target; every other column is a candidate feature. ``test_data`` has no
    target column, so all of its columns are candidate features.

    Args:
        train_data: 2-D array of training rows, target in the last column.
        valid_data: 2-D array of validation rows, target in the last column.
        test_data: 2-D array of test rows without a target column.
        select_all: When true, keep every candidate feature column
            (``feat_idx`` is then ignored).
        feat_idx: Column indices to keep when ``select_all`` is false.

    Returns:
        ``(x_train, x_valid, x_test, y_train, y_valid)``.

    Raises:
        ValueError: If ``select_all`` is false and ``feat_idx`` is not given.
    """
    y_train, y_valid = train_data[:, -1], valid_data[:, -1]
    raw_x_train, raw_x_valid, raw_x_test = train_data[:, :-1], valid_data[:, :-1], test_data

    if select_all:
        feat_idx = list(range(raw_x_train.shape[1]))
    elif feat_idx is None:
        # Previously this path left feat_idx unassigned and failed with an
        # UnboundLocalError on the return line; fail with a clear message.
        raise ValueError('feat_idx must be provided when select_all is False')
    else:
        feat_idx = list(feat_idx)

    return raw_x_train[:, feat_idx], raw_x_valid[:, feat_idx], raw_x_test[:, feat_idx], y_train, y_valid

Configurations

# Use the GPU when one is available, otherwise run on the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Settings for this run.
config = dict(
    seed=5201314,
    select_all=True,
    valid_ratio=0.2,
    n_epochs=3000,
    batch_size=256,
    learning_rate=1e-5,
    early_stop=400,  # stop once the model has not improved for 400 consecutive epochs
    save_path='./models/model.ckpt',  # the model checkpoint is saved here
)

Dataloader

# Seed every random number generator so the run is reproducible.
same_seed(config['seed'])


# train_data.shape = (2699, 118)     id + 37 states + 16 feats * 5 days
# test_data.shape = (1078, 117)      same layout without the final outcome column
train_data, test_data = pd.read_csv('./covid.train.csv').values, pd.read_csv('./covid.test.csv').values

# Hold out part of the training rows for validation.
# Signature: train_valid_split(data_set, valid_ratio, seed)
train_data, valid_data = train_valid_split(train_data, config['valid_ratio'], config['seed'])

# Print the data sizes.
print(f'''train_data size: {train_data.shape}
valid_data size: {valid_data.shape}
test_data size: {test_data.shape}''')

# Select features and split off the targets.
# Signature: select_feat(train_data, valid_data, test_data, select_all=True)
x_train, x_valid, x_test, y_train, y_valid = select_feat(train_data, valid_data, test_data, config['select_all'])

# Print the number of selected features.
print(f'''number of features: {x_train.shape[1]}''')

# Build the datasets; the test set has no targets.
train_dataset = COVID19Dataset(x_train, y_train)
valid_dataset = COVID19Dataset(x_valid, y_valid)
test_dataset = COVID19Dataset(x_test)

# Wrap the datasets in DataLoaders. Only the training loader is shuffled:
# predictions are later written out with ids assigned by position, so the
# test loader must keep the original row order, and an unshuffled validation
# loader gives the same batches (and therefore the same mean loss) every epoch.
train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=config['batch_size'], shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False)
train_data size: (2160, 118)
valid_data size: (539, 118)
test_data size: (1078, 117)
number of features: 117
DataLoader??

Training Loop

def trainer(train_loader, valid_loader, model, config, device):
    """Train ``model`` with SGD, checkpoint the best weights, and stop early.

    Each epoch runs one pass over ``train_loader`` (MSE loss, SGD with
    momentum 0.9) followed by one evaluation pass over ``valid_loader``.
    Whenever the mean validation loss improves on the best seen so far, the
    model's ``state_dict`` is saved to ``config['save_path']``.

    Args:
        train_loader: Iterable of ``(x, y)`` training batches.
        valid_loader: Iterable of ``(x, y)`` validation batches.
        model: The ``nn.Module`` to optimise (updated in place).
        config: Dict providing ``'learning_rate'``, ``'n_epochs'``,
            ``'early_stop'`` and ``'save_path'``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        None. Training ends after ``config['n_epochs']`` epochs, or earlier
        once the validation loss has failed to improve for
        ``config['early_stop']`` consecutive epochs.
    """
    # TODO: Please check https://pytorch.org/docs/stable/optim.html to get more available algorithms.
    # TODO: L2 regularization (optimizer(weight decay...) or implement by your self).
    criterion = nn.MSELoss(reduction='mean')  # loss function
    optimizer = torch.optim.SGD(model.parameters(), lr=config['learning_rate'], momentum=0.9)

    # Create the directory the checkpoint is written to, derived from
    # save_path itself so a custom path works too. dirname is '' when
    # save_path is a bare file name, in which case nothing needs creating.
    save_dir = os.path.dirname(config['save_path'])
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    n_epochs, best_loss, early_stop_count = config['n_epochs'], math.inf, 0

    for epoch in range(n_epochs):
        model.train()  # switch to training mode
        loss_record = []

        # tqdm renders a progress bar over the training batches.
        train_pbar = tqdm(train_loader, position=0, leave=True)

        for x, y in train_pbar:
            optimizer.zero_grad()  # reset the gradients
            x, y = x.to(device), y.to(device)
            pred = model(x)
            loss = criterion(pred, y)
            loss.backward()   # compute the gradients
            optimizer.step()  # update the parameters
            batch_loss = loss.detach().item()
            loss_record.append(batch_loss)

            # Display the current epoch and batch loss on the progress bar.
            train_pbar.set_description(f'Epoch[{epoch+1}/{n_epochs}]')
            train_pbar.set_postfix({'loss': batch_loss})

        mean_train_loss = sum(loss_record) / len(loss_record)

        # Switch to evaluation mode for the validation pass.
        model.eval()
        loss_record = []
        with torch.no_grad():
            for x, y in valid_loader:
                x, y = x.to(device), y.to(device)
                loss = criterion(model(x), y)
                # Store plain floats, as the training loop does, so the mean
                # and best_loss are Python numbers rather than tensors.
                loss_record.append(loss.item())
        mean_valid_loss = sum(loss_record) / len(loss_record)
        if epoch % 5 == 0:
            print(f'Epoch [{epoch+1}/{n_epochs}]: Train loss: {mean_train_loss:.4f}, Valid loss: {mean_valid_loss:.4f}')

        if mean_valid_loss < best_loss:
            best_loss = mean_valid_loss
            # Save the best weights seen so far.
            torch.save(model.state_dict(), config['save_path'])
            early_stop_count = 0
        else:
            early_stop_count += 1

        if early_stop_count >= config['early_stop']:
            print('\nModel is not improving! Stop training session!')
            return
        
        
        
        

Start Training

# A DataLoader has no shape, so the input width is read from the feature array.
model = NN_Model(x_train.shape[1])
model = model.to(device)
# Signature: trainer(train_loader, valid_loader, model, config, device)
trainer(train_loader=train_loader, valid_loader=valid_loader, model=model, config=config, device=device)

Epoch[1/3000]: 100%|██████████████████| 9/9 [00:00<00:00, 250.23it/s, loss=44.6]


Epoch [1/3000]: Train loss: 120.6492, Valid loss: 62.5109


Epoch[2/3000]: 100%|████████████████████| 9/9 [00:00<00:00, 341.45it/s, loss=46]
Epoch[3/3000]: 100%|██████████████████| 9/9 [00:00<00:00, 485.88it/s, loss=38.1]
Epoch[4/3000]: 100%|████████████████████| 9/9 [00:00<00:00, 554.18it/s, loss=33]
Epoch[5/3000]: 100%|██████████████████| 9/9 [00:00<00:00, 480.26it/s, loss=33.1]
Epoch[6/3000]: 100%|██████████████████| 9/9 [00:00<00:00, 446.64it/s, loss=27.9]


Epoch [6/3000]: Train loss: 35.3303, Valid loss: 38.8439


Epoch[7/3000]: 100%|██████████████████| 9/9 [00:00<00:00, 533.33it/s, loss=30.6]
Epoch[8/3000]: 100%|██████████████████| 9/9 [00:00<00:00, 517.69it/s, loss=32.2]
Epoch[9/3000]: 100%|██████████████████| 9/9 [00:00<00:00, 514.35it/s, loss=26.6]
Epoch[10/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 554.39it/s, loss=29.7]
Epoch[11/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 544.46it/s, loss=30.1]


Epoch [11/3000]: Train loss: 30.9335, Valid loss: 28.3564


Epoch[12/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 534.93it/s, loss=22.5]
Epoch[13/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 521.05it/s, loss=26.4]
Epoch[14/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 570.46it/s, loss=27.1]
Epoch[15/3000]: 100%|███████████████████| 9/9 [00:00<00:00, 557.76it/s, loss=28]
Epoch[16/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 505.56it/s, loss=21.1]


Epoch [16/3000]: Train loss: 23.2565, Valid loss: 24.7602


Epoch[17/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 531.66it/s, loss=19.7]
Epoch[18/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 524.36it/s, loss=17.1]
Epoch[19/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 622.11it/s, loss=14.3]
Epoch[20/3000]: 100%|███████████████████| 9/9 [00:00<00:00, 580.43it/s, loss=12]
Epoch[21/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 525.15it/s, loss=20.9]


Epoch [21/3000]: Train loss: 14.5668, Valid loss: 37.8424


Epoch[22/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 105.44it/s, loss=18.7]
Epoch[23/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 598.17it/s, loss=21.5]
Epoch[24/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 517.55it/s, loss=11.8]
Epoch[25/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 523.55it/s, loss=14.4]
Epoch[26/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 581.62it/s, loss=11.7]


Epoch [26/3000]: Train loss: 10.2813, Valid loss: 7.9503


Epoch[27/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 336.25it/s, loss=12.9]
Epoch[28/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 528.10it/s, loss=26.7]
Epoch[29/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 509.48it/s, loss=31.5]
Epoch[30/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 624.43it/s, loss=54.2]
Epoch[31/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 591.26it/s, loss=45.3]


Epoch [31/3000]: Train loss: 35.1419, Valid loss: 30.8228


Epoch[32/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 513.39it/s, loss=19.3]
Epoch[33/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 532.32it/s, loss=18.1]
Epoch[34/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 543.74it/s, loss=19.5]
Epoch[35/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 586.43it/s, loss=11.9]
Epoch[36/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 568.03it/s, loss=9.05]


Epoch [36/3000]: Train loss: 13.0948, Valid loss: 13.3034


Epoch[37/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 500.28it/s, loss=7.06]
Epoch[38/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 528.27it/s, loss=7.84]
Epoch[39/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 540.97it/s, loss=7.28]
Epoch[40/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 606.72it/s, loss=17.6]
Epoch[41/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 607.36it/s, loss=7.56]


Epoch [41/3000]: Train loss: 15.5544, Valid loss: 36.4536


Epoch[42/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 475.48it/s, loss=27.3]
Epoch[43/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 510.92it/s, loss=15.3]
Epoch[44/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 521.66it/s, loss=8.41]
Epoch[45/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 605.12it/s, loss=7.48]
Epoch[46/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 552.45it/s, loss=8.69]


Epoch [46/3000]: Train loss: 9.7680, Valid loss: 11.5441


Epoch[47/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 394.27it/s, loss=8.36]
Epoch[48/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 525.98it/s, loss=11.9]
Epoch[49/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 523.81it/s, loss=7.33]
Epoch[50/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 578.78it/s, loss=8.07]
Epoch[51/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 520.15it/s, loss=5.72]


Epoch [51/3000]: Train loss: 6.5317, Valid loss: 5.2750


Epoch[52/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 555.56it/s, loss=7.72]
Epoch[53/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 522.29it/s, loss=4.57]
Epoch[54/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 482.49it/s, loss=3.73]
Epoch[55/3000]: 100%|██████████████████| 9/9 [00:00<00:00, 568.43it/s, loss=5.9]
Epoch[56/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 514.58it/s, loss=11.8]


Epoch [56/3000]: Train loss: 8.1241, Valid loss: 5.8720


Epoch[57/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 491.16it/s, loss=4.36]
Epoch[58/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 511.47it/s, loss=4.52]
Epoch[59/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 522.31it/s, loss=8.92]
Epoch[60/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 590.86it/s, loss=10.6]
Epoch[61/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 485.31it/s, loss=9.16]


Epoch [61/3000]: Train loss: 9.0994, Valid loss: 7.3874


Epoch[62/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 493.52it/s, loss=7.16]
Epoch[63/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 475.24it/s, loss=9.74]
Epoch[64/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 566.42it/s, loss=5.54]
Epoch[65/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 524.84it/s, loss=5.91]
Epoch[66/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 501.70it/s, loss=7.36]


Epoch [66/3000]: Train loss: 6.1128, Valid loss: 5.4114


Epoch[67/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 501.89it/s, loss=8.69]
Epoch[68/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 494.15it/s, loss=6.44]
Epoch[69/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 552.08it/s, loss=5.46]
Epoch[70/3000]: 100%|██████████████████| 9/9 [00:00<00:00, 331.66it/s, loss=7.9]
Epoch[71/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 565.40it/s, loss=5.18]


Epoch [71/3000]: Train loss: 6.4281, Valid loss: 4.6097


Epoch[72/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 571.40it/s, loss=6.44]
Epoch[73/3000]: 100%|███████████████████| 9/9 [00:00<00:00, 569.63it/s, loss=10]
Epoch[74/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 634.25it/s, loss=6.29]
Epoch[75/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 578.92it/s, loss=6.46]
Epoch[76/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 575.22it/s, loss=5.16]


Epoch [76/3000]: Train loss: 5.7318, Valid loss: 5.1168


Epoch[77/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 558.28it/s, loss=4.02]
Epoch[78/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 558.67it/s, loss=5.75]
Epoch[79/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 635.09it/s, loss=6.12]
Epoch[80/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 599.08it/s, loss=6.61]
Epoch[81/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 521.58it/s, loss=5.37]


Epoch [81/3000]: Train loss: 5.3825, Valid loss: 5.1466


Epoch[82/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 559.97it/s, loss=3.52]
Epoch[83/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 608.11it/s, loss=4.43]
Epoch[84/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 639.70it/s, loss=6.82]
Epoch[85/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 629.72it/s, loss=6.58]
Epoch[86/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 575.89it/s, loss=9.49]


Epoch [86/3000]: Train loss: 7.1615, Valid loss: 7.9000


Epoch[87/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 564.41it/s, loss=5.49]
Epoch[88/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 558.79it/s, loss=3.65]
Epoch[89/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 577.89it/s, loss=4.14]
Epoch[90/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 652.60it/s, loss=4.35]
Epoch[91/3000]: 100%|██████████████████| 9/9 [00:00<00:00, 603.46it/s, loss=7.1]


Epoch [91/3000]: Train loss: 5.8941, Valid loss: 6.2084


Epoch[92/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 590.29it/s, loss=6.14]
Epoch[93/3000]: 100%|██████████████████| 9/9 [00:00<00:00, 589.35it/s, loss=5.4]
Epoch[94/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 590.98it/s, loss=4.46]
Epoch[95/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 644.04it/s, loss=5.57]
Epoch[96/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 691.13it/s, loss=3.64]


Epoch [96/3000]: Train loss: 5.3036, Valid loss: 6.6357


Epoch[97/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 567.65it/s, loss=8.56]
Epoch[98/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 589.28it/s, loss=6.68]
Epoch[99/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 580.15it/s, loss=5.91]
Epoch[100/3000]: 100%|████████████████| 9/9 [00:00<00:00, 584.26it/s, loss=5.26]
Epoch[101/3000]: 100%|████████████████| 9/9 [00:00<00:00, 619.83it/s, loss=5.58]


Epoch [101/3000]: Train loss: 5.0137, Valid loss: 6.6421


Epoch[102/3000]: 100%|████████████████| 9/9 [00:00<00:00, 583.32it/s, loss=5.95]
Epoch[103/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 572.91it/s, loss=8.5]
Epoch[104/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 596.31it/s, loss=5.6]
Epoch[105/3000]: 100%|████████████████| 9/9 [00:00<00:00, 592.66it/s, loss=4.34]
Epoch[106/3000]: 100%|████████████████| 9/9 [00:00<00:00, 114.23it/s, loss=5.92]


Epoch [106/3000]: Train loss: 5.3649, Valid loss: 6.1438


Epoch[107/3000]: 100%|████████████████| 9/9 [00:00<00:00, 589.86it/s, loss=6.57]
Epoch[108/3000]: 100%|████████████████| 9/9 [00:00<00:00, 638.35it/s, loss=5.45]
Epoch[109/3000]: 100%|████████████████| 9/9 [00:00<00:00, 654.87it/s, loss=5.07]
Epoch[110/3000]: 100%|████████████████| 9/9 [00:00<00:00, 548.77it/s, loss=4.64]
Epoch[111/3000]: 100%|████████████████| 9/9 [00:00<00:00, 570.70it/s, loss=5.48]


Epoch [111/3000]: Train loss: 4.9703, Valid loss: 5.3325


Epoch[112/3000]: 100%|████████████████| 9/9 [00:00<00:00, 546.88it/s, loss=5.25]
Epoch[113/3000]: 100%|████████████████| 9/9 [00:00<00:00, 580.29it/s, loss=4.97]
Epoch[114/3000]: 100%|████████████████| 9/9 [00:00<00:00, 638.89it/s, loss=5.27]
Epoch[115/3000]: 100%|████████████████| 9/9 [00:00<00:00, 561.86it/s, loss=5.25]
Epoch[116/3000]: 100%|████████████████| 9/9 [00:00<00:00, 611.25it/s, loss=4.32]


Epoch [116/3000]: Train loss: 4.8477, Valid loss: 5.1231


Epoch[117/3000]: 100%|████████████████| 9/9 [00:00<00:00, 571.83it/s, loss=5.24]
Epoch[118/3000]: 100%|████████████████| 9/9 [00:00<00:00, 577.41it/s, loss=5.16]
Epoch[119/3000]: 100%|████████████████| 9/9 [00:00<00:00, 635.28it/s, loss=6.45]
Epoch[120/3000]: 100%|████████████████| 9/9 [00:00<00:00, 597.40it/s, loss=3.46]
Epoch[121/3000]: 100%|████████████████| 9/9 [00:00<00:00, 583.35it/s, loss=5.23]


Epoch [121/3000]: Train loss: 4.7657, Valid loss: 5.7246


Epoch[122/3000]: 100%|████████████████| 9/9 [00:00<00:00, 584.64it/s, loss=4.46]
Epoch[123/3000]: 100%|████████████████| 9/9 [00:00<00:00, 564.09it/s, loss=5.94]
Epoch[124/3000]: 100%|████████████████| 9/9 [00:00<00:00, 604.14it/s, loss=4.37]
Epoch[125/3000]: 100%|████████████████| 9/9 [00:00<00:00, 653.42it/s, loss=4.03]
Epoch[126/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 508.39it/s, loss=4.2]


Epoch [126/3000]: Train loss: 4.7372, Valid loss: 5.1906


Epoch[127/3000]: 100%|████████████████| 9/9 [00:00<00:00, 558.98it/s, loss=6.33]
Epoch[128/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 510.53it/s, loss=8.5]
Epoch[129/3000]: 100%|████████████████| 9/9 [00:00<00:00, 568.93it/s, loss=12.5]
Epoch[130/3000]: 100%|████████████████| 9/9 [00:00<00:00, 647.90it/s, loss=6.99]
Epoch[131/3000]: 100%|████████████████| 9/9 [00:00<00:00, 559.49it/s, loss=9.89]


Epoch [131/3000]: Train loss: 5.8339, Valid loss: 5.4816


Epoch[132/3000]: 100%|████████████████| 9/9 [00:00<00:00, 572.34it/s, loss=5.14]
Epoch[133/3000]: 100%|████████████████| 9/9 [00:00<00:00, 595.11it/s, loss=5.67]
Epoch[134/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 606.60it/s, loss=5.5]
Epoch[135/3000]: 100%|████████████████| 9/9 [00:00<00:00, 670.29it/s, loss=4.95]
Epoch[136/3000]: 100%|█████████████████| 9/9 [00:00<00:00, 600.04it/s, loss=4.5]


Epoch [136/3000]: Train loss: 4.8287, Valid loss: 5.3120


Epoch[137/3000]: 100%|████████████████| 9/9 [00:00<00:00, 577.52it/s, loss=4.38]
Epoch[138/3000]: 100%|████████████████| 9/9 [00:00<00:00, 574.79it/s, loss=5.22]
Epoch[139/3000]: 100%|████████████████| 9/9 [00:00<00:00, 420.46it/s, loss=5.85]
Epoch[140/3000]: 100%|████████████████| 9/9 [00:00<00:00, 651.84it/s, loss=5.34]
Epoch[141/3000]: 100%|████████████████| 9/9 [00:00<00:00, 595.74it/s, loss=4.31]


Epoch [141/3000]: Train loss: 4.8080, Valid loss: 4.5696


Epoch[142/3000]:   0%|                         | 0/9 [00:00

Testing

def save_pred(preds, file):
    """Write predictions to ``file`` as a CSV with ``id`` and ``tested_positive`` columns.

    Args:
        preds: Iterable of predictions; each row's id is the prediction's
            position in this iterable, starting from 0.
        file: Path of the CSV file to create (overwritten if it exists).
    """
    # newline='' is what the csv module expects for files it writes; without
    # it, platforms that translate '\n' end up with blank lines between rows.
    with open(file, 'w', newline='') as fp:
        writer = csv.writer(fp)
        writer.writerow(['id', 'tested_positive'])
        writer.writerows(enumerate(preds))
# Rebuild the architecture, then restore the weights saved during training.
model = NN_Model(input_dim = x_train.shape[1]).to(device)
# map_location lets a checkpoint written on a GPU be restored on a CPU-only machine.
model.load_state_dict(torch.load(config['save_path'], map_location=device))
preds = predict(test_loader, model, device)
save_pred(preds, 'pred.csv')
100%|████████████████████████████████████████████| 5/5 [00:00<00:00, 895.61it/s]

预测结果截图

李宏毅HW01——新冠疫情数据的预测_第1张图片

你可能感兴趣的:(深度学习,python,人工智能)