pytorch 中 view 和reshape的区别

pytorch 中 view 和reshape的区别_第1张图片

 

在 PyTorch(一个流行的深度学习框架)中,

reshapeview 都是用于改变张量(tensor)形状的方法,但它们在实现方式和使用上有一些区别。下面是它们之间的主要区别:

  1. 实现方式:

    • reshape: reshape 方法创建一个新的张量,其元素与原始张量共享内存空间。这意味着改变形状后,原始张量和新的张量将共享相同的数据存储,所以在一个张量上的修改会影响到另一个张量。
    • view: view 方法并不会创建一个新的张量,而是返回一个与原始张量共享数据存储的新视图(view)。如果原始张量和新的视图张量上的元素被修改,它们会互相影响,因为它们共享相同的数据。
  2. 支持条件:

    • reshape: 可以用于任意形状的变换,但需要确保变换前后元素总数保持一致,否则会抛出错误。
    • view: 只能用于支持大小相同的变换,也就是变换前后元素总数必须保持不变。这是因为 view 并不改变数据的存储,所以必须保持数据总量不变,否则会抛出错误。
  3. 内存连续性

    • reshape: 不保证新张量在内存中的连续性,即可能导致新张量的元素在内存中的存储顺序与原始张量不同。
    • view: 如果原始张量在内存中是连续存储的,那么新视图张量也会保持连续性,否则会返回一个不连续的张量。
  4. 是否支持自动计算维度:

    • reshape: 可以通过将某个维度指定为-1,让 PyTorch 自动计算该维度的大小。
    • view: 不支持将维度指定为-1,需要手动计算新视图张量的大小。当对不连续的张量进行形状变换时,PyTorch 会自动将其复制为连续的张量,这可能会导致额外的内存开销。为了避免这种情况,你可以使用 contiguous() 方法将张量变为连续的。例如:x.contiguous().view(3, 4)
    import torch
    
    # 原始张量
    x = torch.arange(12)
    
    # 使用 reshape
    x_reshaped = x.reshape(3, 4)  # 创建一个新的形状为(3, 4)的张量
    x_reshaped[0, 0] = 100  # 修改新张量的元素会影响到原始张量
    print(x)  # 输出 tensor([100,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11])
    print(x_reshaped)  # 输出 tensor([[100,   1,   2,   3], [  4,   5,   6,   7], [  8,   9,  10,  11]])
    
    # 使用 view
    x_viewed = x.view(3, 4)  # 创建一个新的形状为(3, 4)的张量视图
    x_viewed[0, 1] = 200  # 修改视图张量的元素会影响到原始张量
    print(x)  # 输出 tensor([100, 200,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11])
    print(x_viewed)  # 输出 tensor([[100, 200,   2,   3], [  4,   5,   6,   7], [  8,   9,  10,  11]])
    
    # 使用 view 自动计算维度大小
    x_auto_viewed = x.view(3, -1)  # 可以将某个维度指定为-1,让 PyTorch 自动计算大小
    print(x_auto_viewed)  # 输出 tensor([[100, 200,   2,   3], [  4,   5,   6,   7], [  8,   9,  10,  11]])
    
    # 由于 x_auto_viewed 是连续的,所以修改它也会影响原始张量 x
    x_auto_viewed[2, 2] = 300
    print(x)  # 输出 tensor([100, 200,   2,   3,   4,   5,   6,   7,   8,   9, 300,  11])
    

你可能感兴趣的:(CV基础知识,pytorch,人工智能,python)