Deep learning experiment management tool for PyTorch, TF, etc. (Sacred)

Sacred is a Python library that helps researchers configure, organize, log, and reproduce experiments.
Official documentation
Official GitHub

1. Brief introduction

[Figure 1]
The figure above is taken from this blog.
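Judging from that author's comparison, Sacred is a Python tool that works with any deep learning framework. It is excellent at parameter (configuration) management, but its front-end visualization is comparatively weak.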

2. Example usage

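[Figure 2]
As the figure above shows, Sacred is very simple to use; see the official documentation for its many other experiment-tracking features. To install it, just run (or download the source manually from GitHub):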

pip install sacred
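Sacred's central mechanism is that the variables defined in an `@ex.config` function are injected into the main function by matching parameter names, and can be overridden per run. The stdlib-only toy below mimics that mechanism so it runs without Sacred installed. It is a hypothetical illustration, not Sacred's implementation: `ToyExperiment` and its `run(main_fn, config_updates)` method are made up, and real Sacred collects the config function's local variables without needing a `return`.

```python
import inspect


class ToyExperiment:
    """Hypothetical stand-in for sacred.Experiment (illustration only)."""

    def __init__(self, name):
        self.name = name
        self.config_fns = []

    def config(self, fn):
        # Register a config function; in this toy it must return locals().
        self.config_fns.append(fn)
        return fn

    def run(self, main_fn, config_updates=None):
        # Merge all config functions, then apply per-run overrides.
        cfg = {}
        for fn in self.config_fns:
            cfg.update(fn())
        cfg.update(config_updates or {})
        # Inject config entries whose names match main_fn's parameters.
        params = inspect.signature(main_fn).parameters
        kwargs = {name: cfg[name] for name in params if name in cfg}
        return main_fn(**kwargs)


ex = ToyExperiment("demo")


@ex.config
def my_config():
    batch_size = 100
    learning_rate = 0.001
    return locals()


def main(batch_size, learning_rate):
    return f"batch_size={batch_size}, lr={learning_rate}"


print(ex.run(main))
print(ex.run(main, config_updates={"batch_size": 64}))
```

With real Sacred the same kind of override is available from the command line without editing the script. Assuming the training script in this section is saved as `train.py`, `python train.py with batch_size=64 learning_rate=0.0005` runs it with those values, and `python train.py print_config` shows the resolved configuration.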

The following is a fairly complete training script:

from sacred import Experiment
from sacred.observers import MongoObserver
from sacred.utils import apply_backspaces_and_linefeeds
from sacred.observers import FileStorageObserver

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms


ex = Experiment("mnist_cnn")
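# ex.observers.append(MongoObserver(url='localhost:27017', db_name='sacred'))
# The line above would log runs to a MongoDB database. A database is optional:
# you can record runs to local files with FileStorageObserver instead, as done here:
ex.observers.append(FileStorageObserver('my_exp'))
ex.captured_out_filter = apply_backspaces_and_linefeeds  # clean up progress-bar output (e.g. tqdm) in the captured stdout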


# Hyperparameter configuration
@ex.config
def myconfig():
    # Device configuration (kept as a string so the config stays JSON-serializable)
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    # Hyper parameters
    num_epochs = 5
    num_classes = 10
    batch_size = 100
    learning_rate = 0.001


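# Convolutional neural network (replace with your own network)
class ConvNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        # 1x28x28 -> 16x14x14
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        # 16x14x14 -> 32x7x7, hence the 7 * 7 * 32 inputs of the classifier
        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))
        self.fc = nn.Linear(7 * 7 * 32, num_classes)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)  # flatten to (batch, 7*7*32)
        out = self.fc(out)
        return out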


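# automain: Sacred fills main()'s arguments by name from the values defined in myconfig()
# (_run is the current Run object) and runs main() when the script is executed
@ex.automain
def main(_run, device, num_epochs, num_classes, batch_size, learning_rate):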
    # MNIST dataset
    train_dataset = torchvision.datasets.MNIST(root='/home/ubuntu/Datasets/MINIST/',
                                               train=True,
                                               transform=transforms.ToTensor(),
                                               download=True)

    test_dataset = torchvision.datasets.MNIST(root='/home/ubuntu/Datasets/MINIST/',
                                              train=False,
                                              transform=transforms.ToTensor())

    # Data loader
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True)

    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=batch_size,
                                              shuffle=False)

    model = ConvNet(num_classes).to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Train the model
    total_step = len(train_loader)
    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(train_loader):
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
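            # _run.log_scalar(name, value, step) records a metric; a global step gives each batch its own x-coordinate
            _run.log_scalar('training.loss', loss.item(), epoch * total_step + i)
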
            if (i + 1) % 100 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                      .format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))

    # Test the model
    model.eval()  # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        # Log the test accuracy (in percent)
        _run.log_scalar('test.correct', 100 * correct / total)
        print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))

    # Save the model checkpoint
    torch.save(model.state_dict(), 'model.ckpt')
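The `ex.captured_out_filter = apply_backspaces_and_linefeeds` line in the script matters because progress bars such as tqdm redraw the same line with carriage returns, which would otherwise fill the captured stdout with every intermediate frame. Below is a simplified, stdlib-only sketch of that idea; `apply_cr_and_bs` is a hypothetical name, and this is not Sacred's source (the real function also special-cases a carriage return at the very end of an output chunk).

```python
def apply_cr_and_bs(text):
    """Replay each line like a terminal would (simplified illustration).

    A carriage return moves the cursor to column 0, a backspace moves it
    back one column, and ordinary characters overwrite the cell under the
    cursor (or append at the end of the line).
    """
    out_lines = []
    for line in text.split("\n"):
        chars, cursor = [], 0
        for ch in line:
            if ch == "\r":
                cursor = 0
            elif ch == "\b":
                cursor = max(0, cursor - 1)
            else:
                if cursor == len(chars):
                    chars.append(ch)
                else:
                    chars[cursor] = ch
                cursor += 1
        out_lines.append("".join(chars))
    return "\n".join(out_lines)


# A tqdm-style bar redrawn three times on one line, then a normal line.
raw = "\r".join(f"{p:>3}%|{'#' * (p // 10):<10}|" for p in (10, 50, 100)) + "\nTest done\n"
print(repr(apply_cr_and_bs(raw)))
```

Only the final state of each redrawn line survives, which is what keeps the stored stdout of a run readable.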

3. Results

[Figure 3]
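With the `FileStorageObserver('my_exp')` used in the script, each run is stored in a numbered subdirectory (`my_exp/1`, `my_exp/2`, ...) holding files such as `config.json`, `run.json`, `cout.txt` and `metrics.json`; the values passed to `_run.log_scalar` end up in `metrics.json`. Even without a front end, a few lines of standard-library Python can pull a metric back out. This is a minimal sketch assuming the `metrics.json` layout of recent Sacred versions (each metric name maps to `steps`, `timestamps` and `values` lists); `latest_run_dir` and `load_metric` are hypothetical helpers, and the demo tree uses made-up numbers.

```python
import json
import tempfile
from pathlib import Path


def latest_run_dir(exp_dir):
    """Run subdirectories are numbered 1, 2, 3, ...; return the highest one."""
    runs = [p for p in Path(exp_dir).iterdir() if p.is_dir() and p.name.isdigit()]
    return max(runs, key=lambda p: int(p.name))


def load_metric(run_dir, name):
    """Return (steps, values) of one metric from <run_dir>/metrics.json.

    Assumed layout: {metric_name: {"steps": [...], "timestamps": [...], "values": [...]}}
    """
    with open(Path(run_dir) / "metrics.json") as f:
        entry = json.load(f)[name]
    return entry["steps"], entry["values"]


# Demo on a fabricated my_exp/ tree (made-up numbers, not real training results).
with tempfile.TemporaryDirectory() as tmp:
    exp = Path(tmp) / "my_exp"
    (exp / "_sources").mkdir(parents=True)  # non-numeric directories are skipped
    (exp / "1").mkdir()
    (exp / "2").mkdir()
    fake = {"training.loss": {"steps": [0, 1, 2],
                              "timestamps": ["t0", "t1", "t2"],
                              "values": [2.3, 1.1, 0.6]}}
    (exp / "2" / "metrics.json").write_text(json.dumps(fake))

    run = latest_run_dir(exp)
    steps, values = load_metric(run, "training.loss")
    result = (run.name, steps, values)

print(result)
```

From here the `(steps, values)` pair can be handed to any plotting library.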
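If you set up a database and a front-end visualization tool (MongoDB + Omniboard, or alternatives), the experience is considerably better. The following two blog posts cover installation and usage:
Blog 1
Blog 2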
[Figure 4]
The figure above is from Blog 2.
