【PyTorch】自己的工具类：TensorList

class TensorList(object):
    """A list of tensors that can be concatenated along dim 0 and split back.

    Typical use: per-sample tensors with a different number of rows each
    (e.g. a different number of boxes per image) are concatenated with
    :meth:`cat`, processed once as a single tensor instead of in a Python
    loop, and regrouped per sample with :meth:`split`.

    Attributes:
        tensors: the wrapped tensors, stored as a list.
        len: ``len(t)`` (size of dim 0) of each wrapped tensor, in order.
    """

    def __init__(self, tensors):
        """Wrap ``tensors``.

        Args:
            tensors: an iterable of tensors. It is materialised into a list
                so that a generator is not exhausted by computing ``len``.
        """
        self.tensors = list(tensors)
        self.len = [len(t) for t in self.tensors]

    def __len__(self):
        """Return the number of wrapped tensors (not the total row count)."""
        return len(self.tensors)

    def __iter__(self):
        """Iterate over the wrapped tensors in order."""
        return iter(self.tensors)

    def __getitem__(self, index):
        """Return the wrapped tensor(s) at ``index`` (plain list indexing)."""
        return self.tensors[index]

    def __repr__(self):
        return "{}(len={})".format(type(self).__name__, self.len)

    def tensors_len(self):
        """Return the dim-0 length of each wrapped tensor as a long tensor.

        The result is created directly on the device of the first wrapped
        tensor. For an empty TensorList it is an empty tensor on the default
        device (there is no tensor to take a device from).
        """
        device = self.tensors[0].device if self.tensors else None
        return torch.tensor(
            [len(t) for t in self.tensors], dtype=torch.long, device=device
        )

    def to(self, *args, **kwargs):
        """Return a new TensorList with ``Tensor.to(*args, **kwargs)`` applied
        to every wrapped tensor (device and/or dtype conversion)."""
        tensors = [tensor.to(*args, **kwargs) for tensor in self.tensors]
        return TensorList(tensors)

    def cat(self):
        """Concatenate all wrapped tensors along dim 0 into one tensor.

        An empty TensorList cannot be concatenated (``torch.cat`` rejects an
        empty sequence).
        """
        return torch.cat(self.tensors)

    def split(self, inst_embeddings):
        """Split a tensor along dim 0 back into per-tensor chunks.

        Args:
            inst_embeddings: a tensor whose dim 0 has ``sum(self.len)`` rows,
                ordered like the output of :meth:`cat` (e.g. the result of a
                row-wise op applied to ``self.cat()``).

        Returns:
            A tuple of tensors; the i-th has ``self.len[i]`` rows.
        """
        return torch.split(inst_embeddings, self.len, dim=0)

使用示例

def xyxy2xcycwh(xyxy):
    """Convert boxes from corner format to center/size format.

    Args:
        xyxy: tensor of shape ``(..., 4)`` whose last dimension holds
            ``(x1, y1, x2, y2)``; typically ``(n, 4)``.

    Returns:
        Tensor with the same shape as ``xyxy`` whose last dimension holds
        ``(xc, yc, w, h)``, where ``xc = (x1 + x2) / 2``,
        ``yc = (y1 + y2) / 2``, ``w = x2 - x1`` and ``h = y2 - y1``.

    Raises:
        ValueError: if ``xyxy`` is 0-dimensional or its last dimension is not 4.
    """
    if xyxy.dim() == 0 or xyxy.shape[-1] != 4:
        raise ValueError(
            "expected boxes with a last dimension of size 4, got shape {}".format(
                tuple(xyxy.shape)
            )
        )
    # Unbind the coordinate axis (last), not the box axis: unpacking the tensor
    # directly would iterate over rows and fail for any n != 4.
    x1, y1, x2, y2 = xyxy.unbind(dim=-1)
    xc, yc = (x1 + x2) / 2, (y1 + y2) / 2
    w, h = x2 - x1, y2 - y1
    # Stack on the last axis so the box axis stays first and the output keeps
    # the input's (..., 4) layout.
    return torch.stack([xc, yc, w, h], dim=-1)

import torch
# Three "images" holding 1, 3 and 2 boxes; every coordinate of an image's
# boxes is filled with a constant (1, 2 or 3) so the groups are easy to tell apart.
img1_bbox, img2_bbox, img3_bbox = (
    torch.zeros((rows, 4)) + fill for rows, fill in ((1, 1), (3, 2), (2, 3))
)
imgs_bbox = TensorList([img1_bbox, img2_bbox, img3_bbox])

# Convert every box with a single batched call on the concatenation, instead
# of looping over the images (a Python loop gets slow with many images).
cat_imgs_bbox = xyxy2xcycwh(imgs_bbox.cat())
print(cat_imgs_bbox.shape)

# Regroup the converted rows per image and show each group.
img1_bbox, img2_bbox, img3_bbox = imgs_bbox.split(cat_imgs_bbox)
for per_image_bbox in (img1_bbox, img2_bbox, img3_bbox):
    print(per_image_bbox)

你可能感兴趣的:(pytorch,pytorch,python,深度学习)