torch conv2d: reimplementing convolution from scratch and verifying the algorithm

Conv2d

What is convolution?

[Figure 1: illustration of the 2D convolution operation]
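As a quick intuition for the figure above, here is a minimal sketch (the tensor values are chosen purely for illustration): each output element of a 2D convolution is a sliding dot product between a kernel-sized window of the input and the kernel.

import torch
import torch.nn.functional as F

x = torch.arange(16.).view(1, 1, 4, 4)   # [N=1, Cin=1, Hin=4, Win=4]
w = torch.ones(1, 1, 2, 2)               # [Cout=1, Cin=1, KH=2, KW=2]
y = F.conv2d(x, w)                       # [1, 1, 3, 3], since (4 - 2) // 1 + 1 = 3
# Output element (0, 0) is the dot product of the top-left 2x2 window with the kernel.
manual = (x[0, 0, 0:2, 0:2] * w[0, 0]).sum()
print(y[0, 0, 0, 0].item() == manual.item())  # True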

Two Python implementations of the torch conv2d algorithm are provided here (with support for dilation and groups):

  1. An im2col-based implementation (easier to implement, and faster; see the F.unfold sketch right after this list)
  2. A loop-based implementation (easier to understand)
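
For intuition on the im2col approach, the sketch below (shapes chosen purely for illustration, simplest case: stride 1, no padding, dilation 1, groups 1) shows how F.unfold flattens each sliding window into a column, so that convolution reduces to a matrix multiplication:

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 5, 5)                 # [N, Cin, Hin, Win]
cols = F.unfold(x, kernel_size=(3, 3))      # [N, Cin*KH*KW, Hout*Wout] = [1, 27, 9]
w = torch.randn(4, 3, 3, 3)                 # [Cout, Cin, KH, KW]
# Convolution as a matmul over the unfolded columns, then reshape to [N, Cout, Hout, Wout].
y = (w.view(4, -1) @ cols).view(1, 4, 3, 3)
print(torch.allclose(y, F.conv2d(x, w), atol=1e-4))  # True

The conv2d function below generalizes this idea to stride, padding, dilation, and groups by reshaping the unfolded tensor per group and contracting with einsum.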

Code: https://github.com/Jintao-Huang/ml_alg/blob/main/libs/ml/_ml_alg/_nn_functional.py

The code below includes test code:

import torch
import torch.nn.functional as F
from torch import Tensor
from typing import Optional, Tuple


def conv2d(
    x: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    stride: Tuple[int, int] = (1, 1),
    padding: Tuple[int, int] = (0, 0),
    dilation: Tuple[int, int] = (1, 1),
    groups: int = 1
) -> Tensor:
    """faster than conv2d_2, but more memory. (recommend)
    x: [N, Cin, Hin, Win]
    weight: [Cout, Cin//G, KH, KW]. 
    bias: [Cout]
    stride: SH, SW
    padding: PH, PW
    return: [N, Cout, Hout, Wout]
    """
    Hin, Win = x.shape[2:]
    DH, DW = dilation
    G = groups
    KH, KW = weight.shape[2:]
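    # Effective kernel extent after dilation: a dilated kernel spans (K-1)*D + 1 input pixels.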
    KH_D, KW_D = (KH - 1) * DH + 1, (KW - 1) * DW + 1
    PH, PW = padding
    SH, SW = stride
    N, Cin = x.shape[:2]
    Cout = weight.shape[0]
    # Out = (In + 2*P − ((K-1)*D+1)) // S + 1
    Hout, Wout = (Hin + 2 * PH - KH_D) // SH + 1, (Win + 2 * PW - KW_D) // SW + 1
    assert weight.shape[1] * G == Cin
    assert Cout % G == 0
    # [N, Cin, Hin, Win] -> [N, Cin*KH*KW, Hout*Wout] -> [N, G, Cin//G, KH*KW, Hout*Wout]
    x = F.unfold(x, (KH, KW), (DH, DW), (PH, PW), (SH, SW))
    x = x.view(N, G, Cin//G, KH*KW, Hout*Wout)
    #
    weight = weight.contiguous().view(G, Cout // G, Cin//G, KH*KW)
    # [N, G, Cin//G, KH*KW, Hout*Wout], [G, Cout//G, Cin//G, KH*KW] ->
    #   [N, G, Cout//G, Hout*Wout] -> [N, Cout, Hout, Wout]
    res = torch.einsum("abcde,bfcd->abfe", x, weight).contiguous().view(N, Cout, Hout, Wout)
    #
    if bias is not None:
        res.add_(bias[None, :,  None, None])
    return res


def conv2d_2(
    x: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    stride: Tuple[int, int] = (1, 1),
    padding: Tuple[int, int] = (0, 0),
    dilation: Tuple[int, int] = (1, 1),
    groups: int = 1
) -> Tensor:
    """
    x: [N, Cin, Hin, Win]
    weight: [Cout, Cin//G, KH, KW]. 
    bias: [Cout]
    stride: SH, SW
    padding: PH, PW
    return: [N, Cout, Hout, Wout]
    """
    if padding != (0, 0):
        x = F.pad(x, [padding[1], padding[1], padding[0], padding[0]])  # pad order: left, right, top, bottom
    Hin, Win = x.shape[2:]
    DH, DW = dilation
    G = groups
    KH, KW = weight.shape[2:]
    KH_D, KW_D = (KH - 1) * DH + 1, (KW - 1) * DW + 1
    SH, SW = stride
    N, Cin = x.shape[:2]
    Cout = weight.shape[0]
    assert weight.shape[1] * G == Cin
    assert Cout % G == 0
    # Out = (In + 2*P − ((K-1)*D+1)) // S + 1. (P is already included in In via padding; D is already included in KH_D, KW_D)
    Hout, Wout = (Hin - KH_D) // SH + 1, (Win - KW_D) // SW + 1
    x = x.contiguous().view(N, G, Cin//G, Hin, Win)
    weight = weight.contiguous().view(G, Cout // G, Cin//G, KH, KW)
    res = []
    for i in range(Hout):
        for j in range(Wout):
            h_start, w_start = i * SH, j * SW
            h_pos, w_pos = slice(h_start, (h_start + KH_D), DH), \
                slice(w_start, (w_start + KW_D), DW)
            # [N, G, Cin//G, KH, KW], [G, Cout//G, Cin//G, KH, KW] -> [N, G, Cout//G] -> [N, Cout]
            res.append(torch.einsum("abcde,bfcde->abf", x[:, :, :, h_pos, w_pos], weight))
    res = torch.stack(res, dim=-1).view(N, Cout, Hout, Wout)
    if bias is not None:
        res.add_(bias[None, :,  None, None])
    return res


if __name__ == "__main__":
    torch.backends.cudnn.deterministic = True
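    # Test 1: dilated convolution (dilation=(2, 2)), groups=1.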
    x = torch.randn(16, 128, 112, 112, device="cuda")
    w = torch.randn(256, 128, 3, 3, device="cuda")
    b = torch.randn(256, device="cuda")
    y1 = F.conv2d(x, w, b, (1, 1), (1, 1), (2, 2), 1)
    y2 = conv2d(x, w, b, (1, 1), (1, 1), (2, 2), 1)
    y3 = conv2d_2(x, w, b, (1, 1), (1, 1), (2, 2), 1)
    print(torch.allclose(y1, y2, atol=1e-3))
    print(torch.allclose(y2, y3, atol=1e-3))

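    # Test 2: grouped convolution with groups == Cin (depthwise-style, channel multiplier 2).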
    x = torch.randn(16, 128, 112, 112, device="cuda")
    w = torch.randn(256, 1, 3, 3, device="cuda")
    b = torch.randn(256, device="cuda")
    y1 = F.conv2d(x, w, b, (1, 1), (1, 1), (2, 2), 128)
    y2 = conv2d(x, w, b, (1, 1), (1, 1), (2, 2), 128)
    y3 = conv2d_2(x, w, b, (1, 1), (1, 1), (2, 2), 128)
    print(torch.allclose(y1, y2, atol=1e-3))
    print(torch.allclose(y1, y3, atol=1e-3))
