yolov4_u5版复现—3. 数据读入 dataset.py

1.首先需要了解一下Pytorch的DataLoader, DataSet, Sampler之间的关联

其中DataSet类来自from torch.utils.data import Dataset, 指针对输入数据提供读取方式,数据增强等处理手段。直接继承DataSet实现就可以,主要需要实现以下三个方法,构造函数定义数据处理需要的各种参数,def getitem(self, index)允许通过 dataset[index]的方式访问第index个数据,而对数据处理和增强的主要代码都写在这里,def len(self)允许通过len(dataset)返回dataset对象的长度,即样本个数。而这里的主要参数index就来自于DataLoader中的Sampler

class Dataset(object):
	def __init__(self):
		...
		
	def __getitem__(self, index):
		return ...
	
	def __len__(self):
		return ...

DataLoader的定义如下,pytorch中的数据加载模块 Dataloader,使用Sampler或者batch_sampler(都由生成器实现)来返回数据的索引,使用迭代器 来返回需要的张量数据( 其中输入为dataset和sampler获取的索引 indices,例如:batch = self.collate_fn([self.dataset[i] for i in indices])),可以在大量数据情况下,实现小批量循环迭代式的读取,避免了内存不足问题

关于迭代器,生成器等概念以及其在Dataloader中的使用可以参考这里

class DataLoader(object):
    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 batch_sampler=None, num_workers=0, collate_fn=default_collate,
                 pin_memory=False, drop_last=False, timeout=0,
                 worker_init_fn=None)

关于Pytorch的DataLoader, DataSet, Sampler之间的关联,更详细的可以参考别人的博客:
点击这里,已经有人讲的很详细,就不多赘述,下面直接分析代码。
yolov4_u5版复现—3. 数据读入 dataset.py_第1张图片
创建dataloader 的代码如下:

def create_dataloader(path, img_size, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False, local_rank=-1, world_size=-1):
    """
    path:Txt file containing picture path or folder path containing picture
    img_size:Network input picture size
    batch_size: batch size
    stride:Maximum total stride size of down sampling
    opt:input parameter when use train.py in terminal
    hyp:hyper parameter, Here we mainly use some coefficients about data augmentation (rotation, translation, etc.)
    augment:data augmentation or not
    cache:Whether to cache pictures to memory in advance to speed up training
    pad:padding when setting shape of rectangle training
    rect:Rectangular validation or not
    local_rank: DDP mode or not
    world_size: Number of pc,  set in Multi pc and multi GPU
    """
    with torch_distributed_zero_first(local_rank):
        dataset = LoadImagesAndLabels(path, img_size, batch_size,
                                      augment=augment,  # augment images
                                      hyp=hyp,          # augmentation hyperparameters
                                      rect=rect,        # rectangular validation
                                      cache_images=cache,
                                      single_cls=opt.single_cls,
                                      stride=int(stride),
                                      pad=pad)

    batch_size = min(batch_size, len(dataset))
    num_workers = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, 8])
    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) if local_rank != -1 else None

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             num_workers=num_workers,
                                             sampler=train_sampler,
                                             pin_memory=True,
                                             collate_fn=LoadImagesAndLabels.collate_fn)
    return dataloader, dataset

(1).先了解上下文管理器的相关机制
关于@contextmanager 装饰器的主要机理可参考这里

其中local_rank=-1表示不是DDP模式,local_rank !=-1表示DDP模式,其中local_rank=0表示主进程,其他为从进程。

而在这里代码运行的流程如下:
a. 进入 torch_distributed_zero_first内部 -> b. 运行yield之前的代码 -> c. 返回with中,执行LoadImagesAndLabels() -> d. 进入torch_distributed_zero_first内部,接着运行yield后面的代码

with torch_distributed_zero_first(local_rank):
        dataset = LoadImagesAndLabels(path, img_size, batch_size,
                                      augment=augment,  # augment images
                                      hyp=hyp,          # augmentation hyperparameters
                                      rect=rect,        # rectangular validation
                                      cache_images=cache,
                                      single_cls=opt.single_cls,
                                      stride=int(stride),
                                      pad=pad)
@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """
    Decorator to make all processes in distributed training wait for each local_master to do something.
    """
    if local_rank not in [-1, 0]:
        torch.distributed.barrier()
    yield
    if local_rank == 0:
        torch.distributed.barrier()

在上面的代码示例中,如果执行create_dataloader()函数的进程不是主进程,即rank不等于0或者-1,上下文管理器会执行相应的torch.distributed.barrier(),设置一个阻塞栅栏,让此进程处于等待状态,等待所有进程到达栅栏处(包括主进程数据处理完毕);如果执行create_dataloader()函数的进程是主进程,其会直接去读取数据并处理,然后其处理结束之后会接着遇到torch.distributed.barrier(),此时,所有进程都到达了当前的栅栏处,这样所有进程就达到了同步,并同时得到释放,这里讲解来自视觉弘毅

(2)关于torch.utils.data.DataLoade的参数
num_workers即处理输入数据的进程数,一般根据cpu的性能来设置
train_sampler 如果是DDP模式需要设置DistributedSampler,否则不设置
pin_memory=True,这里可以简单的理解为加速。
collate_fn主要是用来将一个batch中的数据打包在一起。 batch = self.collate_fn([self.dataset[i] for i in indices])

这里关于参数需要注意的是DataLoader的部分初始化参数之间存在互斥关系:

如果你自定义了batch_sampler,那么这些参数都必须使用默认值:batch_size, shuffle,sampler,drop_last.
如果你自定义了sampler,那么shuffle需要设置为False
如果sampler和batch_sampler都为None,那么batch_sampler使用Pytorch已经实现好的BatchSampler,而sampler分两种情况:
若shuffle=True,则sampler=RandomSampler(dataset)
若shuffle=False,则sampler=SequentialSampler(dataset)

	batch_size = min(batch_size, len(dataset))
    num_workers = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, 8])
    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) if local_rank != -1 else None

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batch_size,
                                             num_workers=num_workers,
                                             sampler=train_sampler,
                                             pin_memory=True,
                                             collate_fn=LoadImagesAndLabels.collate_fn)
  1. class LoadImagesAndLabels(Dataset)
    (1) 构造函数,定义参数和部分预处理
    def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
                 cache_images=False, single_cls=False, stride=32, pad=0.0):
        # --------------------------------------------------------------------------------------------------------------
        # define img_files(images path list)      保存图像路径列表
        try:
            files = []        # images path list
            for p in path if isinstance(path, list) else [path]:

                parent_dir = str(Path(p).parent) + os.sep
                dir = str(Path(p))

                if os.path.isfile(dir):  # file
                    with open(dir, 'r') as f:
                        img_dir_list = f.read().splitlines()
                        files += [i.replace('./', parent_dir) if i.startswith('./') else i for i in img_dir_list]
                elif os.path.isdir(dir):  # folder
                    files += glob.iglob(dir + os.sep + '*.*')    # [] + generater 将遍历迭代器中所有元素并添加到[]中
                else:
                    raise Exception('Not exist file dir: %s.' % dir)
            self.img_files = sorted([i.replace('/', os.sep)
                                     for i in files if os.path.splitext(i)[-1].lower() in img_formats])
        except Exception as e:
            raise Exception('Error loading data from %s: %s\nSee %s' % (path, e, help_url))

        num_files = len(self.img_files)
        assert num_files > 0, 'No images found in %s. See %s' % (path, help_url)

        # --------------------------------------------------------------------------------------------------------------
        # define label_files(labels path list)      保存标签路径列表
        self.label_files = [i.replace('images', 'labels').replace(os.path.splitext(i)[-1], 'txt')
                            for i in self.img_files]

        # --------------------------------------------------------------------------------------------------------------
        # define batch index, number of batches      后续Rectangular train/val/test/detection  会用
        batch_index = np.array([i // batch_size for i in range(num_files)], dtype=np.int)
        num_batches = batch_index[-1] + 1

        # --------------------------------------------------------------------------------------------------------------
        # parameters
        self.num_files = num_files
        self.batch = batch_index    # batch index of image
        self.img_size = img_size
        self.augment = augment
        self.hyp = hyp                    # 超参数,这里主要是图像空间增强(平移缩放旋转错切)等参数和色域增强(hsv)
        self.image_weights = image_weights
        self.rect = False if image_weights else rect              # Rectangular train 和 mosaic train只能二选一
        self.mosaic = self.augment and not self.rect   # load 4 images at a time into a mosaic (only during training)
        self.mosaic_border = [-img_size // 2, -img_size // 2]         # mosaic增强时使用
        self.stride = stride

        # --------------------------------------------------------------------------------------------------------------
        # Check cache   cache[img]=[label, shape]        首次读入label则缓存label到 label.cache中
        cache_dir = str(Path(self.label_files[0]).parent) + '.cache'   # cached labels
        if os.path.isfile(cache_dir):
            cache = torch.load(cache_dir)
            if cache['hash'] != get_hash(self.img_files + self.label_files):  # dataset changed
                cache = self.cache_labels(cache_dir)  # re-cache
        else:
            cache = self.cache_labels(cache_dir)    # cache

        # --------------------------------------------------------------------------------------------------------------
        # Get labels
        labels, shapes = zip(*[cache[img] for img in self.img_files])

        self.shapes = np.array(shapes, dtype=np.float64)
        self.labels = labels

        # --------------------------------------------------------------------------------------------------------------
        # Rectangular Training  https://github.com/ultralytics/yolov3/issues/232
        if self.rect:
            s = self.shapes                              # 这里是用Image获取的shape,为w,h,cv2读入的shape顺序为h,w
            aspect_ratio = s[:, 1] / s[:, 0]    # h/w    # 将所有数据按照图像aspect ratio从小到大排列
            aspect_ratio_index = np.argsort(aspect_ratio)

            shape_batch = np.array([[1, 1]] * num_batches)

            self.labels = self.labels[aspect_ratio_index]
            self.img_files = self.img_files[aspect_ratio_index]
            self.label_files = self.label_files[aspect_ratio_index]
            aspect_ratio = aspect_ratio[aspect_ratio_index]

            for batch in range(num_batches):
                aspect_ratio_batch = aspect_ratio[batch_index == batch]   # 获取一个batch的aspect_ratio
                ar_min = aspect_ratio_batch.min()
                ar_max = aspect_ratio_batch.max()
                # 计算batch中ar的最小值和最大值,如果最大值比1还小,即所有图像的h都小于w,则将整个batch的shape设置为h_batch,w_batch=(h/w, 1)
                # 如果最小值最1还大,即所有图形的h都大于w,则将整个batch的shape设置为h_batch,w_batch=(1, w/h)
                # 如果不符合以上条件则h_batch,w_batch=(1, 1)
                if ar_max < 1:
                    shape_batch[batch] = [ar_max, 1]
                elif ar_min > 1:
                    shape_batch[batch] = [1, 1 / ar_min]
            #  使用输入的img_size获取每个batch的图像尺寸,作为 Rectangular Training需要的shape,确保为stride的倍数
            self.batch_shapes = [np.ceil(shape_batch[i] * img_size / stride).astype(np.int) * stride
                                 for i in range(num_batches)]

        # --------------------------------------------------------------------------------------------------------------
        # Cache labels
        num_missing, num_found, num_empty, num_duplicate = 0, 0, 0, 0  # number missing, found, empty, duplicate
        pbar = tqdm(self.label_files)

        for i, file in enumerate(pbar):
            l = self.shapes[i]
            if l.shape[0]:
                assert l.shape[1] == 5, '> 5 label columns: %s' % file
                assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels: %s' % file
                assert (l >= 0).all(), 'negative labels: %s' % file

                if np.unique(l, axis=0).shape[0] < l.shape[0]:  # duplicate rows
                    num_duplicate += 1
                if single_cls:
                    l[:, 0] = 0   # force dataset into single-class mode
                self.shapes[i] = l
                num_found += 1
            else:
                num_empty += 1
            pbar.desc = 'Scanning labels %s (%g found, %g missing, %g empty, %g duplicate, for %g images)' % (
                cache_dir, num_found, num_missing, num_empty, num_duplicate, num_files)

        if num_found == 0:
            s = 'WARNING: No labels found in %s. See %s' % (os.path.dirname(cache_dir) + os.sep, help_url)
            print(s)
            assert not augment, '%s. Can not train without labels.' % s

        # --------------------------------------------------------------------------------------------------------------
        # Cache images into memory for faster training (WARNING: large datasets may exceed system RAM)
        self.imgs = [None] * num_files
        if cache_images:
            gb = 0  # Gigabytes of cached images
            pbar = tqdm(range(len(self.img_files)), desc='Caching images')
            self.img_hw0, self.img_hw = [None] * num_files, [None] * num_files
            for i in pbar:
                self.imgs[i], self.img_hw0[i], self.img_hw[i] = self.load_image(i)
                gb += self.imgs[i].nbytes
                pbar.desc = 'Caching images (%.1fGB)' % (gb / 1E9)

        # --------------------------------------------------------------------------------------------------------------

(2)def getitem(self, index), 图像增强

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, index):
        # --------------------------------------------------------------------------------------------------------------
        # data augmentation,
        # if train, use Mosaic or Rectangular(when self.rect=true),
        # if val or test, use Rectangular, letterbox(but not minimum rectangle)   
        # if detection, use Rectangular, letterbox(minimum rectangle)   只有推理时才使用最小矩形框,加速推理
        if self.mosaic:
            img, label = self.load_mosaic(index)
            shapes = None
        else:
            img, (h0, w0), (h, w) = self.load_image(index)
            shape = self.batch_shapes[self.batch[index]] if self.rect else self.img_size    # final letterboxed shape
            img, ratio, (dw, dh) = letterbox(img, shape, auto=False, scaleup=self.augment)

            shapes = (h0, w0), ((h / h0, w / w0), (dw, dh))  # for COCO mAP rescaling  ??????

            # according to letterbox change corresponding label
            x = self.labels[index]
            label = x.copy()
            if x.size > 0:
                # Normalized xywh to  xywh     ratio * w, ratio * h表示rect后的矩形框尺寸,dw和dh表示图像四周填充的边框
                label[:, 1] = ratio * w * x[:, 1] + dw  # pad width
                label[:, 2] = ratio * h * x[:, 2] + dh  # pad height
                label[:, 3] = ratio * w * x[:, 3]
                label[:, 4] = ratio * h * x[:, 4]

        # --------------------------------------------------------------------------------------------------------------
        # image augmentation, image space    缩放,平移,旋转,错切
        if self.augment:
            if not self.mosaic:
                img, label = random_affine(img, label, self.hyp['degrees'], self.hyp['translate'],
                                           self.hyp['scale'], self.hyp['shear'])

        # --------------------------------------------------------------------------------------------------------------
        # image augmentation, color space    色域转换
            augment_hsv(img, self.hyp['hsv_h'], self.hyp['hsv_s'], self.hyp['hsv_v'])
        # --------------------------------------------------------------------------------------------------------------
        # label normalization
        if label.size > 0:
            label[:, 1] = label[:, 1] / img.shape[1]   # x
            label[:, 2] = label[:, 2] / img.shape[0]   # y
            label[:, 3] = label[:, 3] / img.shape[1]   # w
            label[:, 4] = label[:, 4] / img.shape[0]   # h

        # --------------------------------------------------------------------------------------------------------------
        # image augmentation, flip  
        if self.augment:
            # random left-right flip
            lr_flip = True
            if lr_flip and random.random() < 0.5:
                img = np.fliplr(img)
                if label.size > 0:
                    label[:, 1] = 1 - label[:, 1]

            # random up-down flip
            ud_flip = False
            if ud_flip and random.random() < 0.5:
                img = np.flipud(img)
                if label.size > 0:
                    label[:, 2] = 1 - label[:, 2]

        # --------------------------------------------------------------------------------------------------------------
        # output    labels新增一列用来存储在每个batch中每个图像的索引
        num_label = len(label)
        labels = torch.zeros((num_label, 6))
        if num_label:
            labels[:, 1:] = torch.from_numpy(label)

        img = img[:, :, ::-1].transpose(2, 0, 1)   # BGR to RGB, (h, w, 3) to (3, h, w)
        img = np.ascontiguousarray(img)
        #此处的返回值为单张图像数据,一个打包成batch的操作在collate_fn中完成
        return torch.from_numpy(img), labels, self.img_files[index], shapes
        # --------------------------------------------------------------------------------------------------------------

(3) def collate_fn(batch), torch.utils.data.DataLoade中负责打包一个batch数据的函数,当_getitem__返回的数据不止image,label时,需要重写。

	@staticmethod
    def collate_fn(batch):
        # merge a batch images and labels
        batch_imgs, batch_labels, batch_paths, batch_shapes = zip(*batch)
        for i, label in enumerate(batch_labels):
            if label.size[0]:
                label[:, 0] = i
        return torch.stack(batch_imgs, dim=0), torch.cat(batch_labels, dim=0), batch_paths, batch_shapes

(4)mosaic

    def load_mosaic(self, index):
        # mosaic for image , mosaic means combining four images to a new image
        # --------------------------------------------------------------------------------------------------------------
        label_mosaic = []     # label after mosaic
        indices = [index] + [random.randint(0, self.num_files-1) for _ in range(3)]   # current image and other 3 images

        # --------------------------------------------------------------------------------------------------------------
        # (x_cross, y_cross) is cross point of four images,   四张图像合并的交点
        # Range(0.5*img_mosaic_w ~ 1.5*img_mosaic_w, 0.5*img_mosaic_h ~ 1.5*img_mosaic_h)
        y_cross, x_cross = [int(random.uniform(-x, 2 * self.img_size + x)) for x in self.mosaic_border]

        # --------------------------------------------------------------------------------------------------------------
        # put four small pictures into large mosaic picture
        for i, index in enumerate(indices):
            # ----------------------------------------------------------------------------------------------------------
            img, _, (h, w) = self.load_image(index)  # current image

            # (x1_mosaic, y1_mosaic, x2_mosaic, y2_mosaic) means (x_min,y_min,x_max,y_max) in mosaic image
            # (x1, y1, x2, y2) means (x_min,y_min,x_max,y_max) in one of four small images
            if i == 0:  # upper left image
                # ------------------------------------------------------------------------------------------------------
                # define mosaic image
                img_mosaic = np.full((2 * self.img_size, 2 * self.img_size, img.shape[2]), 114, dtype=np.uint8)

                # ------------------------------------------------------------------------------------------------------
                x1_mosaic, y1_mosaic, x2_mosaic, y2_mosaic = max(x_cross-w, 0), max(y_cross-h, 0), x_cross, y_cross
                x1, y1, x2, y2 = w - (x2_mosaic-x1_mosaic), h - (y2_mosaic-y1_mosaic), w, h
            if i == 1:  # upper right image
                x1_mosaic, y1_mosaic, x2_mosaic, y2_mosaic = x_cross, max(y_cross-h, 0), \
                                                             min(x_cross+w, 2*self.img_size), y_cross
                 # 这里原来的代码中 x2 = min(w,  x2_mosaic-x1_mosaic), 我认为x2_mosaic-x1_mosaic永远不会大于w,所以
                 #没有必要这样设置,直接x2 =x2_mosaic-x1_mosaic就好,后面有好几处也是这个道理
                x1, y1, x2, y2 = 0, h - (y2_mosaic-y1_mosaic), x2_mosaic-x1_mosaic, h
            if i == 2:  # bottom left image
                x1_mosaic, y1_mosaic, x2_mosaic, y2_mosaic = max(x_cross-w, 0), y_cross, \
                                                             x_cross, min(y_cross+h, 2*self.img_size)
                x1, y1, x2, y2 = w-(x2_mosaic-x1_mosaic), 0, w, y2_mosaic-y1_mosaic
            if i == 3:  # bottom right image
                x1_mosaic, y1_mosaic, x2_mosaic, y2_mosaic = x_cross, y_cross, min(x_cross+w, 2*self.img_size), \
                                                             min(y_cross+h, 2*self.img_size)
                x1, y1, x2, y2 = 0, 0, x2_mosaic-x1_mosaic, y2_mosaic-y1_mosaic

            img_mosaic[y1_mosaic:y2_mosaic, x1_mosaic:x2_mosaic] = img[y1:y2, x1:x2]

            # ----------------------------------------------------------------------------------------------------------
            label = self.labels[index]

            if label.size > 0:
                # normalized xywh to xyxy
                x = label.copy()
                label[:, 1] = (x[:, 1] - x[:, 3] / 2) * w
                label[:, 2] = (x[:, 2] - x[:, 4] / 2) * h
                label[:, 3] = (x[:, 1] + x[:, 3] / 2) * w
                label[:, 4] = (x[:, 2] + x[:, 4] / 2) * h
                # 利用(x1,y1)和 (x1_mosaic,y1_mosaic )是相同的顶点来转换坐标系
                # distance between label and (x1, y1)
                delta_x1 = label[:, 1] - x1
                delta_y1 = label[:, 2] - y1
                delta_x2 = label[:, 3] - x1
                delta_y2 = label[:, 4] - y1

                # convert label coordinate system from img to img_mosaic
                label[:, 1] = delta_x1 + x1_mosaic
                label[:, 2] = delta_y1 + y1_mosaic
                label[:, 3] = delta_x2 + x1_mosaic
                label[:, 4] = delta_y2 + y1_mosaic

            label_mosaic.append(label)

        # --------------------------------------------------------------------------------------------------------------
        if len(label_mosaic):
            label_mosaic = np.concatenate(label_mosaic, axis=0)

            # clip label_mosaic
            np.clip(label_mosaic[:, 1:], 0, 2 * self.img_size, out=label_mosaic[:, 1:])

        # ----------------------------------------------------------------------------------------------------------
        # Augment
        img_mosaic, label_mosaic = random_affine(img_mosaic, label_mosaic,
                                      degrees=self.hyp['degrees'],
                                      translate=self.hyp['translate'],
                                      scale=self.hyp['scale'],
                                      shear=self.hyp['shear'],
                                      border=self.mosaic_border)  # border to remove
        return img_mosaic, label_mosaic

(5)load image

    def load_image(self, index):
        # loads 1 image from dataset, returns img, original hw, resized hw
        img = self.imgs[index]
        if img is None:            # not cached
            path = self.img_files[index]
            img0 = cv2.imread(path)  # BGR
            assert img0 is not None, 'Image Not Found ' + path

            h0, w0 = img0.shape[0:2]   # orig hw
            ratio = self.img_size / max(h0, w0)   # resize image to img_size

            if ratio != 1:   # always resize down, only resize up if training with augmentation
                interp = cv2.INTER_AREA if ratio < 1 and not self.augment else cv2.INTER_LINEAR
                h, w = int(h0 * ratio), int(w0 * ratio)
                img = cv2.resize(img0, (w, h), interpolation=interp)
            else:
                h, w = h0, w0
                img = img0
            return img, (h0, w0), (h, w)     # img, hw_original, hw_resized
        else:
            return img, self.img_hw0[index], self.img_hw[index]   # img, hw_original, hw_resized

(6) letterbox, yolo v5特有的矩形训练/测试/推理,其中推理时采用最小矩形框,可减少计算量,增加速度。
关于矩形推理的原理可以参考江大白在yolo v5讲解中的说明

def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=False, scaleFill=False, scaleup=True):
    shape = img.shape[:2]   # h,w

    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)   # h,w

    # Scale ratio (new / old)
    ratio = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:   # only scale down for val/test/detection
        ratio = min(ratio, 1.0)

    # pad: image padding in w and h
    new_shape_pad = int(round(shape[1] * ratio)), int(round(shape[0] * ratio))     # w,h
    pad = new_shape[1] - new_shape_pad[0], new_shape[0] - new_shape_pad[1]   # w,h

    if auto:   # minimum rectangle
        pad = np.mod(pad[0], 32), np.mod(pad[1], 32)

    pad /= 2    # divide padding into 2 sides

    if shape[::-1] != new_shape_pad:
        img = cv2.resize(img, new_shape_pad, interpolation=cv2.INTER_LINEAR)

    top, bottom = int(round(pad[1] - 0.1)), int(round(pad[1] + 0.1))
    left, right = int(round(pad[0] - 0.1)), int(round(pad[0] + 0.1))

    img = cv2.copyMakeBorder(img, top, bottom, left, right, borderType=cv2.BORDER_CONSTANT, value=color)

    return img, ratio, (pad[0], pad[1])

你可能感兴趣的:(pytorch,深度学习,python,计算机视觉,神经网络)