einops.rearrange, repeat, reduce: operations on data dimensions

Works with numpy arrays and torch tensors (among other backends).

Contents

1. einops.rearrange: reorder dimensions

2. einops.repeat: reorder and repeat (add) dimensions

3. einops.reduce


1. einops.rearrange: reorder dimensions


def rearrange(tensor, pattern, **axes_lengths):
einops.rearrange is a reader-friendly smart element reordering for multidimensional tensors. This operation includes functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze, stack, concatenate and other operations.
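The operations listed in that description map onto familiar NumPy calls. The sketch below writes a few rearrange patterns as plain NumPy so the shapes are easy to check; it illustrates the semantics only and is not how einops is implemented internally. The einops pattern each line corresponds to is given in its comment.

```python
import numpy as np

# 32 random RGB images of height 30 and width 40, each in "h w c" layout
images = [np.random.randn(30, 40, 3) for _ in range(32)]

# 'b h w c -> b h w c' applied to a list: stack along a new first axis
stacked = np.stack(images, axis=0)          # (32, 30, 40, 3)

# 'b h w c -> b c h w': pure axis permutation (transpose)
chw = stacked.transpose(0, 3, 1, 2)         # (32, 3, 30, 40)

# 'b h w c -> (b h) w c': merge b and h, i.e. concatenate along height
tall = stacked.reshape(32 * 30, 40, 3)      # (960, 40, 3)

# 'b h w c -> b 1 h w c': insert an axis of length 1 (unsqueeze)
unsq = stacked[:, np.newaxis]               # (32, 1, 30, 40, 3)

# 'b 1 h w c -> b h w c': drop that axis again (squeeze)
sq = unsq.squeeze(axis=1)                   # (32, 30, 40, 3)
```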

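In a pattern, parentheses mean merging: the axes inside one pair of parentheses are combined into a single axis (and on the input side, a single axis is split into them).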
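Inside a group the leftmost axis varies slowest, just like a C-order reshape. So `(b h)` keeps each image's rows together, while `(h b)` interleaves rows coming from different images, even though both produce the same shape. A small NumPy sketch of both orders (the matching einops patterns are in the comments; this shows the semantics, not the library's internals):

```python
import numpy as np

b, h, w, c = 4, 3, 5, 2
x = np.arange(b * h * w * c).reshape(b, h, w, c)

# 'b h w c -> (b h) w c': b is the outer factor, so all rows of image 0 come first
bh = x.reshape(b * h, w, c)

# 'b h w c -> (h b) w c': move h in front of b, then merge; rows are interleaved
hb = x.transpose(1, 0, 2, 3).reshape(h * b, w, c)
```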

import numpy as np
from einops import rearrange

# suppose we have a set of 32 images in "h w c" format (height-width-channel)
images = [np.random.randn(30, 40, 3) for _ in range(32)]
print("data shape", len(images), images[0].shape)  # 32 (30, 40, 3)
# stack along first (batch) axis, output is a single array: (32, 30, 40, 3)
print(rearrange(images, 'b h w c -> b h w c').shape)
# swap height and width of every image: (32, 40, 30, 3)
print(rearrange(images, 'b h w c -> b w h c').shape)
# concatenate images along height (vertical axis), 960 = 32 * 30: (960, 40, 3)
print(rearrange(images, 'b h w c -> (b h) w c').shape)
# concatenate images along width (horizontal axis), 1280 = 32 * 40: (30, 1280, 3)
print(rearrange(images, 'b h w c -> h (b w) c').shape)
# reorder axes to "b c h w" format for deep learning: (32, 3, 30, 40)
print(rearrange(images, 'b h w c -> b c h w').shape)
# flatten each image into a vector, 3600 = 30 * 40 * 3: (32, 3600)
print(rearrange(images, 'b h w c -> b (c h w)').shape)
# split the channel axis as (c ph); with ph=1, c=3 and the channels are stacked vertically: (32, 90, 40)
print(rearrange(images, 'b h w (c ph) -> b (c h) (w ph)', ph=1).shape)
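The last line both splits an axis (`(c ph)` on the input side) and merges axes (`(c h)` and `(w ph)` on the output side). Below is a rough NumPy equivalent wrapped in `split_merge`, a hypothetical helper name used only for this illustration, so that other values of `ph` can be tried. It assumes the input is already a stacked `(b, h, w, c*ph)` array.

```python
import numpy as np

def split_merge(x, ph):
    """Hypothetical helper (not part of einops): NumPy version of
    rearrange(x, 'b h w (c ph) -> b (c h) (w ph)', ph=ph)."""
    b, h, w, cph = x.shape
    if cph % ph:
        raise ValueError("last axis must be divisible by ph")
    c = cph // ph
    y = x.reshape(b, h, w, c, ph)      # split the last axis; c is the outer factor
    y = y.transpose(0, 3, 1, 2, 4)     # reorder to b, c, h, w, ph
    return y.reshape(b, c * h, w * ph) # merge (c h) and (w ph)

images = np.random.randn(32, 30, 40, 3)
print(split_merge(images, ph=1).shape)  # (32, 90, 40)
print(split_merge(images, ph=3).shape)  # (32, 30, 120)
```

With `ph=3` the whole channel axis goes into `ph`, so `c` becomes 1 and the three channels end up side by side along the width instead.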


# Splitting an axis: with (h h1) or (h1 h), the new h is 1/h1 of the original height (likewise for w);
# which factor is written first decides whether the pieces are contiguous blocks or interleaved pixels.

# split each image into 4 smaller ones (top-left, top-right, bottom-left, bottom-right), 128 = 32 * 2 * 2: (128, 15, 20, 3)
print(rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape)

# space-to-depth operation, each 2x2 neighborhood is moved into the channel axis: (32, 15, 20, 12)
print(rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape)
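The order inside a split group decides what the pieces look like. With `(h1 h)` the outer factor selects a half of the image, so the four pieces are contiguous quadrants; with `(h h1)` the inner factor selects every other pixel, so each piece is a strided sub-sample of the whole image. The NumPy check below shows both, plus space-to-depth, on a small array (an illustration of the semantics, not einops' own code):

```python
import numpy as np

b, H, W, c = 2, 4, 6, 1
x = np.arange(b * H * W * c).reshape(b, H, W, c)
h1 = w1 = 2
h, w = H // h1, W // w1

# 'b (h1 h) (w1 w) c -> (b h1 w1) h w c': contiguous quadrants
quads = (x.reshape(b, h1, h, w1, w, c)
          .transpose(0, 1, 3, 2, 4, 5)
          .reshape(b * h1 * w1, h, w, c))

# 'b (h h1) (w w1) c -> (b h1 w1) h w c': interleaved (strided) sub-images
strided = (x.reshape(b, h, h1, w, w1, c)
            .transpose(0, 2, 4, 1, 3, 5)
            .reshape(b * h1 * w1, h, w, c))

# 'b (h h1) (w w1) c -> b h w (c h1 w1)': space-to-depth
s2d = (x.reshape(b, h, h1, w, w1, c)
        .transpose(0, 1, 3, 5, 2, 4)
        .reshape(b, h, w, c * h1 * w1))
```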


2. einops.repeat: reorder and repeat (add) dimensions


einops.repeat allows reordering elements and repeating them in arbitrary combinations. This operation includes functionality of repeat, tile, broadcast functions.
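As with rearrange, the position of the new axis inside a group matters. In `(repeat h)` the new axis is the outer factor, so the whole image is copied as a block (like `np.tile`); in `(h h2)` it is the inner factor, so each row is duplicated in place (like `np.repeat`). A NumPy sketch of the three basic cases, with the matching einops pattern in each comment:

```python
import numpy as np

image = np.arange(6).reshape(2, 3)   # tiny "h w" image

# 'h w -> h w c', c=3: add a new axis and copy the image along it
rgb = np.repeat(image[:, :, np.newaxis], 3, axis=2)   # (2, 3, 3)

# 'h w -> (repeat h) w', repeat=2: tile the whole image vertically
tiled = np.tile(image, (2, 1))                         # (4, 3)

# 'h w -> (h h2) w', h2=2: duplicate each row in place
stretched = np.repeat(image, 2, axis=0)                # (4, 3)

print(tiled[:, 0])      # [0 3 0 3]
print(stretched[:, 0])  # [0 0 3 3]
```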

import numpy as np
from einops import repeat, reduce

# a grayscale image (of shape height x width)
image = np.random.randn(30, 40)

# change it to RGB format by repeating in each channel: (30, 40, 3)
print(repeat(image, 'h w -> h w c', c=3).shape)

# repeat image 2 times along height (vertical axis): (60, 40)
print(repeat(image, 'h w -> (repeat h) w', repeat=2).shape)

# repeat image 3 times along width (horizontal axis): (30, 120)
print(repeat(image, 'h w -> h (repeat w)', repeat=3).shape)

# convert each pixel to a small 2x2 square, i.e. upsample image by 2x: (60, 80)
print(repeat(image, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape)

# pixelate image: first downsample by 2x (mean of each 2x2 block), then upsample: (30, 40)
downsampled = reduce(image, '(h h2) (w w2) -> h w', 'mean', h2=2, w2=2)
print(repeat(downsampled, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape)
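The pixelation example combines both directions: `reduce` with `'mean'` averages each 2x2 block, and `repeat` with `(h h2) (w w2)` writes each average back into a 2x2 block, so the result has the original shape but is constant within every block. A NumPy sketch of the same two steps, assuming height and width are divisible by 2 (the einops patterns require exact division):

```python
import numpy as np

image = np.random.randn(30, 40)
h2 = w2 = 2
H, W = image.shape

# reduce(image, '(h h2) (w w2) -> h w', 'mean', h2=2, w2=2)
downsampled = image.reshape(H // h2, h2, W // w2, w2).mean(axis=(1, 3))  # (15, 20)

# repeat(downsampled, 'h w -> (h h2) (w w2)', h2=2, w2=2)
pixelated = np.repeat(np.repeat(downsampled, h2, axis=0), w2, axis=1)   # (30, 40)
```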


3. einops.reduce


einops.reduce provides combination of reordering and reduction using reader-friendly notation.
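Axes that appear on the left of the arrow but not on the right are reduced with the named operation (such as 'max', 'min', 'sum', 'mean' or 'prod'); the remaining axes are rearranged exactly as in rearrange. In the simple case this is just a NumPy reduction over the missing axis, as the sketch below shows (einops patterns in the comments):

```python
import numpy as np

x = np.random.randn(100, 32, 64)   # (time, batch, channel)

# reduce(x, 't b c -> b c', 'max'): 't' is absent on the right, so it is reduced
max_over_time = x.max(axis=0)                    # (32, 64)

# reduce(x, 't b c -> c b', 'mean'): reduce 't' and also swap b and c
mean_swapped = x.mean(axis=0).transpose(1, 0)    # (64, 32)
```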

import numpy as np
from einops import rearrange,reduce
 
x = np.random.randn(100, 32, 64)
# perform max-reduction on the first axis:(32, 64)
print(reduce(x, 't b c -> b c', 'max').shape) 
 
# same as previous, but with clearer axes meaning:(32, 64)
print(reduce(x, 'time batch channel -> batch channel', 'max').shape)
 
x = np.random.randn(10, 20, 30, 40)
# 2d max-pooling with kernel size = 2 * 2 for image processing:(10, 20, 15, 20)
y1 = reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h2=2, w2=2)
print(y1.shape)
 
# to go back to the original height and width, apply the depth-to-space trick (channels shrink from 20 to 5): (10, 5, 30, 40)
y2 = rearrange(y1, 'b (c h2 w2) h1 w1 -> b c (h1 h2) (w1 w2)', h2=2, w2=2)
print(y2.shape)
 
# adaptive 2d max-pooling to a 3 * 4 grid, kernel size 10 * 10 is inferred (30 and 40 must divide evenly): (10, 20, 3, 4)
print(reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h1=3, w1=4).shape)
 
# Global average pooling:(10, 20)
print(reduce(x, 'b c h w -> b c', 'mean').shape)
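Both pooling examples use the same pattern; they differ only in which factor of each group is fixed. Fixing `h2=2, w2=2` sets the kernel size and lets einops infer the output grid, while fixing `h1=3, w1=4` sets the output grid and lets einops infer the kernel (30/3 = 10 by 40/4 = 10), which is why it acts like adaptive pooling when the sizes divide evenly. A NumPy sketch of the reshape-then-reduce idea, using `grid_max_pool`, a hypothetical helper that is not part of einops:

```python
import numpy as np

def grid_max_pool(x, h1, w1):
    """Hypothetical helper (not part of einops): max-pool a (b, c, H, W) array
    onto an h1 x w1 grid, like
    reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h1=h1, w1=w1)."""
    b, c, H, W = x.shape
    if H % h1 or W % w1:
        raise ValueError("spatial size must divide evenly into the grid")
    h2, w2 = H // h1, W // w1          # inferred kernel size
    return x.reshape(b, c, h1, h2, w1, w2).max(axis=(3, 5))

x = np.random.randn(10, 20, 30, 40)
print(grid_max_pool(x, 15, 20).shape)   # 2x2 kernel: (10, 20, 15, 20)
print(grid_max_pool(x, 3, 4).shape)     # 10x10 kernel: (10, 20, 3, 4)
print(x.mean(axis=(2, 3)).shape)        # global average pooling: (10, 20)
```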


————————————————
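Copyright notice: This is an original article by CSDN blogger 「马鹏森」, licensed under CC 4.0 BY-SA. When reposting, please include a link to the original source and this notice.
Original link: https://blog.csdn.net/weixin_43135178/article/details/118877384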
