一个比较完整的 PyTorch 项目

模型定义（输入为形状 (batch, 1, 5000) 的一维序列）：

import torch
import torch.nn as nn

class DF(nn.Module):
    """1-D convolutional classifier: four conv blocks, then two FC layers.

    Args:
        nb_classes: Number of output classes (width of the final Linear).
        input_length: Length of the 1-D input sequence. The default, 5000,
            reproduces the originally hard-coded ``nn.Linear(3328, 512)``
            (256 channels x 13 positions); other lengths size ``fc1``
            automatically.

    The input is expected as ``(batch, 1, input_length)`` because the first
    conv has ``in_channels=1``. ``forward`` returns ``(logits, features)``,
    where ``features`` is the 512-d output of ``fc2``.

    Raises:
        ValueError: If ``input_length`` is too short to survive the
            convolution/pooling stack.
    """

    def __init__(self, nb_classes, input_length=5000):
        super(DF, self).__init__()
        # Every Conv1d below is kernel 8, stride 1, no padding (length - 7);
        # every MaxPool1d(8, 4, 0) maps length L to (L - 8) // 4 + 1.
        # Block 1 uses ELU, blocks 2-4 use ReLU.
        self.block1 = nn.Sequential(
            nn.Conv1d(
                in_channels=1,
                out_channels=32,
                kernel_size=8,
                stride=1,
                padding=0,
            ),
            nn.BatchNorm1d(32),
            nn.ELU(alpha=1.0),
            nn.Conv1d(32, 32, 8, 1, 0),
            nn.BatchNorm1d(32),
            nn.ELU(alpha=1.0),
            nn.MaxPool1d(8, 4, 0),
            nn.Dropout(0.1),
        )

        self.block2 = nn.Sequential(
            nn.Conv1d(32, 64, 8, 1, 0),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Conv1d(64, 64, 8, 1, 0),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(8, 4, 0),
            nn.Dropout(0.1),
        )

        self.block3 = nn.Sequential(
            nn.Conv1d(64, 128, 8, 1, 0),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Conv1d(128, 128, 8, 1, 0),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.MaxPool1d(8, 4, 0),
            nn.Dropout(0.1),
        )

        self.block4 = nn.Sequential(
            nn.Conv1d(128, 256, 8, 1, 0),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Conv1d(256, 256, 8, 1, 0),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.MaxPool1d(8, 4, 0),
            nn.Dropout(0.1),
        )

        # Derive the flattened feature count from the layers above instead of
        # hard-coding it (3328 for input_length=5000).
        n_flat = self._flat_features(input_length)

        self.fc1 = nn.Sequential(
            nn.Flatten(),
            nn.Linear(n_flat, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.7),
        )

        self.fc2 = nn.Sequential(
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),
        )

        self.out = nn.Sequential(
            nn.Linear(512, nb_classes),
        )

    def _flat_features(self, input_length):
        """Return channels * length after block1..block4 for ``input_length``.

        Applies the standard Conv1d / MaxPool1d output-size formula
        ``(L + 2*padding - dilation*(kernel - 1) - 1) // stride + 1``
        (all MaxPool1d layers here use the default ``ceil_mode=False``).

        Raises:
            ValueError: If the sequence length drops below 1 at any layer.
        """
        def _scalar(value):
            # Conv1d stores kernel_size/stride/... as 1-tuples, MaxPool1d as ints.
            return value[0] if isinstance(value, tuple) else value

        length = input_length
        channels = 1
        for block in (self.block1, self.block2, self.block3, self.block4):
            for layer in block:
                if not isinstance(layer, (nn.Conv1d, nn.MaxPool1d)):
                    continue
                kernel = _scalar(layer.kernel_size)
                stride = _scalar(layer.stride)
                padding = _scalar(layer.padding)
                dilation = _scalar(layer.dilation)
                length = (length + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1
                if isinstance(layer, nn.Conv1d):
                    channels = layer.out_channels
                if length < 1:
                    raise ValueError(
                        f"input_length={input_length} is too short for DF: "
                        f"sequence length drops to {length} after {layer!r}"
                    )
        return channels * length

    def forward(self, x):
        """Run the network.

        Args:
            x: Tensor of shape ``(batch, 1, input_length)``.

        Returns:
            Tuple ``(logits, features)``: logits of shape
            ``(batch, nb_classes)`` and the ``fc2`` activations of shape
            ``(batch, 512)`` (penultimate-layer features).
        """
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.fc1(x)
        x = self.fc2(x)
        output = self.out(x)
        return output, x

模型训练（假设 `train_dl`、`test_dl` 已在别处定义好）：

# Training script: trains DF on `train_dl` and evaluates on `test_dl`.
# NOTE(review): `train_dl` and `test_dl` are not defined in this snippet; they
# are assumed to yield (x, y) batches with x of shape (batch, 5000) and y
# holding integer class labels -- TODO confirm against where they are built.
train_loader = train_dl
NB_CLASSES = 50
EPOCH = 100
BATCH_SIZE = 128  # informational; the loader's real batch size is set where train_dl is built
LR = 0.001
EVAL_EVERY = 1  # run validation every N training steps (1 = every step, as originally)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cnn = DF(NB_CLASSES).float().to(device)

loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
validation_loader = test_dl
validation_size = 450  # expected size; evaluate() counts the actual samples
train_size = 4050  # expected size; used only in the progress message

# Per-step training losses and validation accuracies, stored as plain Python
# floats so no autograd graphs or device tensors are kept alive.
resnet_loss = []
resnet_accuracy = []


def _prepare_batch(b_x, b_y, target_device):
    """Move a batch to ``target_device`` and shape it for DF / CrossEntropyLoss.

    Inputs go from (batch, length) to (batch, 1, length) float; labels become
    a 1-D int64 vector. ``reshape(-1)`` is used instead of ``squeeze()`` so a
    final batch of size 1 keeps its batch dimension.
    """
    inputs = b_x.to(target_device).unsqueeze(-2).float()
    targets = b_y.to(target_device).reshape(-1).long()
    return inputs, targets


def evaluate(model, loader, criterion, target_device):
    """Return ``(mean per-sample loss, correct count, sample count)`` on ``loader``.

    Runs in eval mode (BatchNorm uses running stats, Dropout is disabled) and
    without autograd, then restores the model's previous train/eval mode.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    corrects = 0
    total = 0
    with torch.no_grad():
        for b_x, b_y in loader:
            inputs, targets = _prepare_batch(b_x, b_y, target_device)
            logits = model(inputs)[0]
            n = targets.size(0)
            # criterion averages over the batch; weight by batch size so the
            # result is a true per-sample mean across uneven batches.
            total_loss += criterion(logits, targets).item() * n
            corrects += (logits.argmax(dim=1) == targets).sum().item()
            total += n
    model.train(was_training)
    avg_loss = total_loss / total if total else 0.0
    return avg_loss, corrects, total


cnn.train()
for epoch in range(EPOCH):
    seen = 0  # training samples processed so far in this epoch
    for step, (b_x, b_y) in enumerate(train_loader):
        inputs, targets = _prepare_batch(b_x, b_y, device)  # e.g. (128, 1, 5000)
        output = cnn(inputs)[0]
        loss = loss_func(output, targets)
        resnet_loss.append(loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        seen += targets.size(0)

        if step % EVAL_EVERY == 0:
            avg_loss, corrects, size = evaluate(cnn, validation_loader, loss_func, device)
            accuracy = 100.0 * corrects / size if size else 0.0
            resnet_accuracy.append(accuracy)

            print('Epoch: {:2d}({:6d}/{}) Evaluation - loss: {:.6f}  acc: {:3.4f}%({}/{})'.format(
                epoch,
                seen,
                train_size,
                avg_loss,
                accuracy,
                corrects,
                size))

你可能感兴趣的：PyTorch