pytorch显存不足时的解决办法

  1. 将float32转化为float16,是最有效的降低显存占用的方式,可以降低一半左右的显存占用。
    实现方式:首先在代码的最前面加上
torch.set_default_dtype(torch.float16)

这行代码会把程序中此后新建的浮点张量(包括模型参数)的默认类型设为float16。
此时如果直接运行程序,会报出“输入为float(float32)而参数为half(float16)”的类型不匹配错误。
解决办法是将输入也转换为half(float16),
代码如下

 inputs = inputs.type(torch.float16)

不建议采用方法1,因为之后使用cuDNN加速时容易报出下面


RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED
You can try to repro this exception using the following code snippet. If that doesn't trigger the error, please include your original repro script when reporting this issue.

import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.allow_tf32 = True
data = torch.randn([4, 32, 119, 159], dtype=torch.half, device='cuda', requires_grad=True)
net = torch.nn.Conv2d(32, 64, kernel_size=[5, 5], padding=[0, 0], stride=[2, 2], dilation=[1, 1], groups=1)
net = net.cuda().half()
out = net(data)
out.backward(torch.randn_like(out))
torch.cuda.synchronize()

ConvolutionParams 
    data_type = CUDNN_DATA_HALF
    padding = [0, 0, 0]
    stride = [2, 2, 0]
    dilation = [1, 1, 0]
    groups = 1
    deterministic = false
    allow_tf32 = true
input: TensorDescriptor 000001E834DE5180
    type = CUDNN_DATA_HALF
    nbDims = 4
    dimA = 4, 32, 119, 159, 
    strideA = 605472, 18921, 159, 1, 
output: TensorDescriptor 000001E834DE3AC0
    type = CUDNN_DATA_HALF
    nbDims = 4
    dimA = 4, 64, 58, 78, 
    strideA = 289536, 4524, 78, 1, 
weight: FilterDescriptor 000001E8349A6610
    type = CUDNN_DATA_HALF
    tensor_format = CUDNN_TENSOR_NCHW
    nbDims = 4
    dimA = 64, 32, 5, 5, 
Pointer addresses: 
    input: 0000002363108000
    output: 00000023637A8800
    weight: 0000002305E01600
Additional pointer addresses: 
    grad_output: 00000023637A8800
    grad_input: 0000002363108000
Backward data algorithm: 1

的错误。
2. 利用pytorch的checkpoint特性,可以极大地降低显存的使用。
实现方式:
对于DenseNet,torchvision官方已经给出了checkpoint实现:构建densenet模型时传入参数 memory_efficient=True 即可降低显存占用。
对于其它网络,则可以通过

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from torchvision.datasets.cifar import CIFAR10
import numpy as np
from progressbar import progressbar


def conv_bn_relu(in_ch, out_ch, ker_sz, stride, pad):
    """Return Conv2d(bias=False) -> BatchNorm2d -> ReLU as one nn.Sequential.

    The convolution omits its bias because the BatchNorm that follows already
    adds a learnable shift.
    """
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, ker_sz, stride, pad, bias=False),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(),
    )


class NetA(nn.Module):
    """Small 10-class CNN used to compare GPU memory with/without checkpointing.

    The network is split into five segments (``seg0`` .. ``seg4``). With
    ``use_checkpoint=True`` each segment is wrapped in
    ``torch.utils.checkpoint.checkpoint`` while autograd is recording, so the
    activations inside a segment are recomputed during backward instead of
    being kept in memory.

    Args:
        use_checkpoint: enable segment-wise gradient checkpointing.
    """

    def __init__(self, use_checkpoint=False):
        super().__init__()
        self.use_checkpoint = use_checkpoint

        k = 2  # channel width multiplier
        # The size comments below assume a 32x32 input.
        # 32x32
        self.layer1 = conv_bn_relu(3, 32*k, 3, 1, 1)
        self.layer2 = conv_bn_relu(32*k, 32*k, 3, 2, 1)
        # 16x16
        self.layer3 = conv_bn_relu(32*k, 64*k, 3, 1, 1)
        self.layer4 = conv_bn_relu(64*k, 64*k, 3, 2, 1)
        # 8x8
        self.layer5 = conv_bn_relu(64*k, 128*k, 3, 1, 1)
        self.layer6 = conv_bn_relu(128*k, 128*k, 3, 2, 1)
        # 4x4
        self.layer7 = conv_bn_relu(128*k, 256*k, 3, 1, 1)
        self.layer8 = conv_bn_relu(256*k, 256*k, 3, 2, 1)
        # 1x1 after global average pooling
        self.layer9 = nn.Linear(256*k, 10)

    def seg0(self, y):
        """Segment 0: layer1."""
        y = self.layer1(y)
        return y

    def seg1(self, y):
        """Segment 1: layer2 -> layer3."""
        y = self.layer2(y)
        y = self.layer3(y)
        return y

    def seg2(self, y):
        """Segment 2: layer4 -> layer5."""
        y = self.layer4(y)
        y = self.layer5(y)
        return y

    def seg3(self, y):
        """Segment 3: layer6 -> layer7."""
        y = self.layer6(y)
        y = self.layer7(y)
        return y

    def seg4(self, y):
        """Segment 4: layer8 -> global avg pool -> flatten -> linear classifier."""
        y = self.layer8(y)
        y = F.adaptive_avg_pool2d(y, 1)
        y = torch.flatten(y, 1)
        y = self.layer9(y)
        return y

    def forward(self, x):
        """Run all five segments; returns logits of shape (N, 10)."""
        segments = (self.seg0, self.seg1, self.seg2, self.seg3, self.seg4)

        # Checkpointing only pays off when autograd is recording; under
        # torch.no_grad() (evaluation) it saves nothing and only emits warnings.
        if not (self.use_checkpoint and torch.is_grad_enabled()):
            y = x
            for seg in segments:
                y = seg(y)
            return y

        # Reentrant checkpointing produces no parameter gradients for a segment
        # unless at least one tensor input requires grad. The image batch does
        # not, so add a zero tensor that does (values are unchanged). This is
        # done only on the checkpoint path to avoid an extra input gradient
        # computation when checkpointing is off.
        y = x + torch.zeros(1, dtype=x.dtype, device=x.device, requires_grad=True)
        for seg in segments:
            y = checkpoint(seg, y)
        return y


def _load_split(root, train):
    """Load one CIFAR-10 split as (images, labels) numpy arrays.

    Images are uint8 in the layout stored by ``CIFAR10.data``; labels are int64.
    """
    dataset = CIFAR10(root, train, download=True)
    images = np.asarray(dataset.data, np.uint8)
    # np.int was deprecated in NumPy 1.20 and removed in 1.24; use an explicit width.
    labels = np.asarray(dataset.targets, np.int64)
    return images, labels


def _iter_batches(images, labels, batch_size, size=224):
    """Yield (x, y) CUDA batches with a progress bar.

    x is float in [0, 1], permuted from NHWC to NCHW and bilinearly resized to
    ``size`` x ``size``; y is int64 class labels.
    """
    batch_count = int(np.ceil(len(images) / batch_size))
    for b_id in progressbar(range(batch_count)):
        lo, hi = batch_size * b_id, batch_size * (b_id + 1)
        batch_x = torch.from_numpy(images[lo:hi]).permute(0, 3, 1, 2).float() / 255.
        batch_y = torch.from_numpy(labels[lo:hi]).long()
        batch_x = batch_x.cuda()
        batch_y = batch_y.cuda()
        # Upsample 32x32 CIFAR images so activation memory is large enough
        # for the checkpointing effect to be visible.
        batch_x = F.interpolate(batch_x, (size, size), mode='bilinear', align_corners=False)
        yield batch_x, batch_y


if __name__ == '__main__':
    net = NetA(use_checkpoint=True).cuda()

    train_x, train_y = _load_split('../datasets/cifar10', True)
    # Evaluate on the held-out split rather than on the training data.
    test_x, test_y = _load_split('../datasets/cifar10', False)

    losser = nn.CrossEntropyLoss()
    optim = torch.optim.Adam(net.parameters(), 1e-3)

    epoch = 10
    batch_size = 31

    for e_id in range(epoch):
        print('epoch', e_id)

        # Shuffle at the start of every epoch, including the first one.
        ids = np.random.permutation(len(train_x))
        train_x = train_x[ids]
        train_y = train_y[ids]

        print('training')
        net.train()
        loss_sum = 0.0
        batch_count = 0
        for batch_x, batch_y in _iter_batches(train_x, train_y, batch_size):
            optim.zero_grad()
            loss = losser(net(batch_x), batch_y)
            loss.backward()
            optim.step()
            loss_sum += loss.item()
            batch_count += 1
        print('loss', loss_sum / batch_count)

        print('testing')
        net.eval()
        correct = 0
        with torch.no_grad():
            for batch_x, batch_y in _iter_batches(test_x, test_y, batch_size):
                pred = net(batch_x).argmax(dim=1)
                correct += (pred == batch_y).sum().item()
        # Divide by the sample count, not the batch count, so the shorter
        # final batch is not over-weighted.
        print('acc', correct / len(test_x))

这种方式添加。
但在实际修改ResNet的过程中发现:有时修改后显存并没有立即减少,过一段时间后才会减少;而且自己实现的版本节省显存的效果不如官方实现。

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

class BasicBlock(nn.Module):
    """Two-convolution residual block used by ResNet-18 and ResNet-34.

    The block outputs ``out_channels * expansion`` channels. ``expansion`` is a
    class attribute so ResNet can compute stage widths without an instance
    (BottleNeck uses a different value).
    """

    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        width = out_channels * BasicBlock.expansion

        # Main path: 3x3 conv (applies the stride) -> BN -> ReLU -> 3x3 conv -> BN.
        self.residual_function = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, width, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(width),
        )

        # Identity shortcut when shapes already match; otherwise a strided
        # 1x1 conv + BN projects the input to the main path's output shape.
        if stride == 1 and in_channels == width:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, width, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(width),
            )

    def forward(self, x):
        """Return ReLU(residual_function(x) + shortcut(x))."""
        merged = self.residual_function(x) + self.shortcut(x)
        return F.relu(merged, inplace=True)

class BottleNeck(nn.Module):
    """Three-convolution (1x1, 3x3, 1x1) residual block used by ResNet-50/101/152.

    The block outputs ``out_channels * expansion`` channels.
    """

    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        wide = out_channels * BottleNeck.expansion

        self.residual_function = nn.Sequential(
            # 1x1 conv to out_channels
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            # 3x3 conv; this is where the block's stride is applied
            nn.Conv2d(out_channels, out_channels, stride=stride, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            # 1x1 conv to out_channels * expansion
            nn.Conv2d(out_channels, wide, kernel_size=1, bias=False),
            nn.BatchNorm2d(wide),
        )

        # Project the shortcut with a strided 1x1 conv + BN only when the
        # spatial size or channel count changes; otherwise pass x through.
        needs_projection = stride != 1 or in_channels != wide
        projection = (
            [nn.Conv2d(in_channels, wide, stride=stride, kernel_size=1, bias=False),
             nn.BatchNorm2d(wide)]
            if needs_projection
            else []
        )
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        """Return ReLU(residual_function(x) + shortcut(x))."""
        summed = self.residual_function(x) + self.shortcut(x)
        return F.relu(summed, inplace=True)

class ResNet(nn.Module):
    """ResNet backbone with two linear heads.

    ``forward`` returns ``(x1, x2)``: ``x1`` has shape (N, 1) and is passed
    through a sigmoid; ``x2`` has shape (N, 7) and is returned as-is.

    Args:
        block: residual block class (e.g. BasicBlock or BottleNeck). It must
            expose an ``expansion`` class attribute and accept
            ``(in_channels, out_channels, stride)``.
        num_block: number of blocks in each of the four stages conv2_x..conv5_x.
        num_classes: currently unused (the plain ``fc`` classifier is disabled);
            kept for backward compatibility.
        use_checkpoint: if True (the default, as before), the stem and each
            stage run under ``torch.utils.checkpoint.checkpoint`` while autograd
            is recording, trading recomputation for activation memory.
    """

    def __init__(self, block, num_block, num_classes=100, use_checkpoint=True):
        super().__init__()
        self.use_checkpoint = use_checkpoint

        self.in_channels = 64

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True))
        # The input size differs from the original paper, so conv2_x uses stride 1.
        self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
        self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
        self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
        self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.classFc1 = nn.Linear(512 * block.expansion, 1)
        self.regFc1 = nn.Linear(512 * block.expansion, 7)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Stack ``num_blocks`` residual blocks into one stage.

        "Layer" here means a stage made of several residual blocks, not a single
        neural-network layer such as one conv layer.

        Args:
            block: block class (BasicBlock or BottleNeck).
            out_channels: base output channel count of each block.
            num_blocks: number of blocks in this stage.
            stride: stride of the first block; the remaining blocks use 1.

        Returns:
            nn.Sequential of the blocks. Also updates ``self.in_channels`` to the
            stage's output width so the next stage is built with the right input.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for s in strides:
            layers.append(block(self.in_channels, out_channels, s))
            self.in_channels = out_channels * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``(sigmoid(classFc1(f)), regFc1(f))`` for pooled features ``f``."""
        stages = (self.conv1, self.conv2_x, self.conv3_x, self.conv4_x, self.conv5_x)

        if self.use_checkpoint and torch.is_grad_enabled():
            # Reentrant checkpointing produces no gradients for a segment's
            # parameters unless at least one tensor input requires grad. The
            # image batch does not, and without this every stage (and every
            # later checkpointed op) would be cut off from backprop. Adding a
            # zero tensor that requires grad leaves the values unchanged.
            # NOTE(review): BatchNorm layers inside a checkpointed stage are run
            # again during backward, which likely updates their running stats
            # twice per step; confirm if exact running statistics matter.
            output = x + torch.zeros(1, dtype=x.dtype, device=x.device, requires_grad=True)
            for stage in stages:
                output = checkpoint(stage, output)
        else:
            output = x
            for stage in stages:
                output = stage(output)

        output = self.avg_pool(output)
        output = output.view(output.size(0), -1)

        # The heads are tiny; checkpointing them saves no meaningful memory.
        x1 = torch.sigmoid(self.classFc1(output))
        x2 = self.regFc1(output)

        return x1, x2
    
def _resnet(block, blocks_per_stage):
    """Instantiate ResNet from a block class and per-stage block counts."""
    return ResNet(block, list(blocks_per_stage))


def resnet18():
    """Return a ResNet-18: BasicBlock with (2, 2, 2, 2) blocks per stage."""
    return _resnet(BasicBlock, (2, 2, 2, 2))


def resnet34():
    """Return a ResNet-34: BasicBlock with (3, 4, 6, 3) blocks per stage."""
    return _resnet(BasicBlock, (3, 4, 6, 3))


def resnet50():
    """Return a ResNet-50: BottleNeck with (3, 4, 6, 3) blocks per stage."""
    return _resnet(BottleNeck, (3, 4, 6, 3))


def resnet101():
    """Return a ResNet-101: BottleNeck with (3, 4, 23, 3) blocks per stage."""
    return _resnet(BottleNeck, (3, 4, 23, 3))


def resnet152():
    """Return a ResNet-152: BottleNeck with (3, 8, 36, 3) blocks per stage."""
    return _resnet(BottleNeck, (3, 8, 36, 3))




对于ResNet-50只能降低约1/3的显存,而对于ResNet-152,显存占用则可以降到原来的一半以下。

你可能感兴趣的:(pytoch,pytorch)