Python torch.save() Examples

Example 1

def train(epoch):
    """Train the model for one epoch, then save its weights to disk.

    Relies on names defined outside this function: ``model``,
    ``train_loader``, ``optimizer``, ``args`` (reads ``args.cuda`` and
    ``args.log_interval``), ``F`` and ``Variable``.

    Args:
        epoch: Epoch number. Used in the progress message and, zero-padded
            to three digits, in the name of the saved weights file.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            # The loss is a 0-dim tensor: `loss.data[0]` raises IndexError on
            # current PyTorch releases, `.item()` is the supported way to get
            # the Python number.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

    # Save the model weights once per call, i.e. once per epoch.
    torch.save(model.state_dict(), "mnist_model_{0:03d}.pwf".format(epoch))

Example 2

def train(model, db, args, bsz=32, eph=1, use_cuda=False):
    """Train ``model`` on ``db`` with SGD and save checkpoints along the way.

    A checkpoint is written to the ``saved_model`` directory every time a
    batch produces a loss lower than any seen so far, and once more after the
    last epoch. Relies on ``target_onehot_to_classnum_tensor``, ``evaluate``
    and ``cuda_ava`` being defined outside this function.

    Args:
        model: Network to train; its ``parameters()`` are optimized in place.
        db: Dataset handed to the ``DataLoader``.
        args: Not used by this function; kept for interface compatibility.
        bsz: Batch size.
        eph: Number of epochs.
        use_cuda: Move each batch to the GPU when true and ``cuda_ava`` is
            also true.
    """
    print("Training...")

    trainloader = data_utils.DataLoader(dataset=db, batch_size=bsz, shuffle=True)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
    # Start from +inf so the very first batch always counts as an improvement,
    # whatever the scale of the loss.
    best_loss = float('inf')

    # torch.save does not create missing directories, so make sure the target
    # exists before the first checkpoint is written.
    save_dir = 'saved_model'
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)

    for epoch in range(eph):
        running_loss = 0.0
        for i, data in enumerate(trainloader, 1):
            inputs, targets = data
            inputs = inputs.unsqueeze(1)
            targets = target_onehot_to_classnum_tensor(targets)
            if use_cuda and cuda_ava:
                inputs = Variable(inputs.float().cuda())
                targets = Variable(targets.cuda())
            else:
                inputs = Variable(inputs.float())
                targets = Variable(targets)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # The loss is a 0-dim tensor: `loss.data[0]` raises IndexError on
            # current PyTorch releases, `.item()` returns the Python float.
            last_loss = loss.item()
            running_loss += last_loss
            if i % 100 == 0:
                # Mean loss over the last 100 batches.
                print("[%d, %5d] loss: %.3f" % (epoch + 1, i, running_loss / 100))
                running_loss = 0.0

            if last_loss < best_loss:
                best_loss = last_loss
                # NOTE(review): evaluation runs on the training loader, not on
                # held-out data -- confirm this is intended.
                acc = evaluate(model, trainloader, use_cuda)
                checkpoint_name = 'cnnT1_epoch_{}_iter_{}_loss_{}_acc_{}_{}.t7'.format(
                    epoch + 1, i, last_loss, acc,
                    datetime.datetime.now().strftime("%b_%d_%H:%M:%S"))
                torch.save(model.state_dict(), os.path.join(save_dir, checkpoint_name))

    # Final checkpoint after all epochs, named after the final accuracy.
    acc = evaluate(model, trainloader, use_cuda)
    torch.save(model.state_dict(), os.path.join(save_dir, 'cnnT1_all_acc_{}.t7'.format(acc)))

    print("Finished Training!")

Example 3

def save_checkpoint(self, state, is_best):
    """
    Write ``state`` to the checkpoint directory so it can be reloaded later.

    The checkpoint is always stored as ``<model name>_ckpt.pth.tar`` inside
    ``self.ckpt_dir``. When ``is_best`` is true, that file is additionally
    copied to ``<model name>_model_best.pth.tar`` in the same directory.

    Args:
        state: Object passed to ``torch.save``.
        is_best: Whether this checkpoint should also be kept as the best one.
    """
    print("[*] Saving model to {}".format(self.ckpt_dir))

    latest_name = self.get_model_name() + '_ckpt.pth.tar'
    latest_path = os.path.join(self.ckpt_dir, latest_name)
    torch.save(state, latest_path)

    if not is_best:
        return

    # Keep a separate copy so later, worse checkpoints do not overwrite it.
    best_name = self.get_model_name() + '_model_best.pth.tar'
    best_path = os.path.join(self.ckpt_dir, best_name)
    shutil.copyfile(latest_path, best_path)
    print("[*] ==== Best Valid Acc Achieved ====")

Example 4

def save_checkpoint(self, state, is_best, filename='checkpoint.pth.tar'):
    """Save ``state`` with ``torch.save`` and keep copies of the best checkpoint.

    When the ``save_model_path`` setting is present, the checkpoint is written
    inside that directory (created if missing); otherwise ``filename`` is used
    as given. When ``is_best`` is true the checkpoint is also copied to
    ``model_best.pth.tar`` and, if ``save_model_path`` is set, to a
    score-stamped ``model_best_<score>.pth.tar`` that replaces the previously
    recorded one.

    Args:
        state: Object passed to ``torch.save``. ``state['best_score']`` is
            read when ``is_best`` is true and ``save_model_path`` is set.
        is_best: Whether this checkpoint is the best one so far.
        filename: File name of the regular checkpoint.
    """
    has_save_dir = self._state('save_model_path') is not None
    if has_save_dir:
        save_dir = self.state['save_model_path']
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        filename = os.path.join(save_dir, filename)

    print('save model {filename}'.format(filename=filename))
    torch.save(state, filename)

    if not is_best:
        return

    filename_best = 'model_best.pth.tar'
    if has_save_dir:
        filename_best = os.path.join(save_dir, filename_best)
    shutil.copyfile(filename, filename_best)

    if has_save_dir:
        previous_best = self._state('filename_previous_best')
        # The recorded file may have been deleted or moved by hand; a missing
        # file must not abort the checkpoint step.
        if previous_best is not None and os.path.exists(previous_best):
            os.remove(previous_best)
        filename_best = os.path.join(
            save_dir,
            'model_best_{score:.4f}.pth.tar'.format(score=state['best_score']))
        shutil.copyfile(filename, filename_best)
        # Remember the stamped copy so the next best checkpoint can replace it.
        self.state['filename_previous_best'] = filename_best

你可能感兴趣的:(Python torch.save() Examples)