What is convolution?
A 2D convolution slides a small kernel over the input feature map and computes a dot product at every position. This post provides two Python implementations of torch's conv2d (supporting dilation and groups): conv2d is built on F.unfold (im2col), while conv2d_2 slides the kernel window explicitly in Python loops.
Code: https://github.com/Jintao-Huang/ml_alg/blob/main/libs/ml/_ml_alg/_nn_functional.py
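As a quick worked example of the output-size formula used in the code (Out = (In + 2*P - ((K-1)*D + 1)) // S + 1), take the first test setting below: Hin = 112, PH = 1, KH = 3, DH = 2, SH = 1. The dilated kernel spans (3 - 1) * 2 + 1 = 5 rows, so Hout = (112 + 2*1 - 5) // 1 + 1 = 110.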
The full code, including test code, follows:
import torch
import torch.nn.functional as F
from torch import Tensor
from typing import Optional, Tuple
def conv2d(
    x: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    stride: Tuple[int, int] = (1, 1),
    padding: Tuple[int, int] = (0, 0),
    dilation: Tuple[int, int] = (1, 1),
    groups: int = 1
) -> Tensor:
    """Faster than conv2d_2, but uses more memory (recommended).
    x: [N, Cin, Hin, Win]
    weight: [Cout, Cin//G, KH, KW]
    bias: [Cout]
    stride: SH, SW
    padding: PH, PW
    dilation: DH, DW
    groups: G
    return: [N, Cout, Hout, Wout]
    """
    Hin, Win = x.shape[2:]
    DH, DW = dilation
    G = groups
    KH, KW = weight.shape[2:]
    KH_D, KW_D = (KH - 1) * DH + 1, (KW - 1) * DW + 1
    PH, PW = padding
    SH, SW = stride
    N, Cin = x.shape[:2]
    Cout = weight.shape[0]
    # Out = (In + 2*P - ((K-1)*D+1)) // S + 1
    Hout, Wout = (Hin + 2 * PH - KH_D) // SH + 1, (Win + 2 * PW - KW_D) // SW + 1
    assert weight.shape[1] * G == Cin
    assert Cout % G == 0
    # [N, Cin, Hin, Win] -> [N, Cin*KH*KW, Hout*Wout] -> [N, G, Cin//G, KH*KW, Hout*Wout]
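    # F.unfold (im2col) extracts every dilated (KH, KW) window, with the given
    # padding and stride, and flattens each window into a column, so the
    # convolution turns into a batched matrix multiplication over those columns.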
    x = F.unfold(x, (KH, KW), (DH, DW), (PH, PW), (SH, SW))
    x = x.view(N, G, Cin//G, KH*KW, Hout*Wout)
    #
    weight = weight.contiguous().view(G, Cout // G, Cin//G, KH*KW)
    # [N, G, Cin//G, KH*KW, Hout*Wout], [G, Cout//G, Cin//G, KH*KW] ->
    # [N, G, Cout//G, Hout*Wout] -> [N, Cout, Hout, Wout]
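    # The einsum contracts the shared Cin//G ("c") and KH*KW ("d") axes within each group.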
    res = torch.einsum("abcde,bfcd->abfe", x, weight).contiguous().view(N, Cout, Hout, Wout)
    #
    if bias is not None:
        res.add_(bias[None, :, None, None])
    return res
def conv2d_2(
    x: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    stride: Tuple[int, int] = (1, 1),
    padding: Tuple[int, int] = (0, 0),
    dilation: Tuple[int, int] = (1, 1),
    groups: int = 1
) -> Tensor:
    """Slower than conv2d, but uses less memory.
    x: [N, Cin, Hin, Win]
    weight: [Cout, Cin//G, KH, KW]
    bias: [Cout]
    stride: SH, SW
    padding: PH, PW
    dilation: DH, DW
    groups: G
    return: [N, Cout, Hout, Wout]
    """
    if padding != (0, 0):
        x = F.pad(x, [padding[1], padding[1], padding[0], padding[0]])  # (left, right, top, bottom)
    Hin, Win = x.shape[2:]
    DH, DW = dilation
    G = groups
    KH, KW = weight.shape[2:]
    KH_D, KW_D = (KH - 1) * DH + 1, (KW - 1) * DW + 1
    SH, SW = stride
    N, Cin = x.shape[:2]
    Cout = weight.shape[0]
    assert weight.shape[1] * G == Cin
    assert Cout % G == 0
    # Out = (In + 2*P - ((K-1)*D+1)) // S + 1. (P and D are already folded into In and KH_D/KW_D here.)
    Hout, Wout = (Hin - KH_D) // SH + 1, (Win - KW_D) // SW + 1
    x = x.contiguous().view(N, G, Cin//G, Hin, Win)
    weight = weight.contiguous().view(G, Cout // G, Cin//G, KH, KW)
    res = []
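    # Explicit sliding window: for every output position, slice the dilated
    # (KH, KW) patch and contract it with the weights of its group.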
    for i in range(Hout):
        for j in range(Wout):
            h_start, w_start = i * SH, j * SW
            h_pos, w_pos = slice(h_start, (h_start + KH_D), DH), \
                slice(w_start, (w_start + KW_D), DW)
            # [N, G, Cin//G, KH, KW], [G, Cout//G, Cin//G, KH, KW] -> [N, G, Cout//G] -> [N, Cout]
            res.append(torch.einsum("abcde,bfcde->abf", x[:, :, :, h_pos, w_pos], weight))
    res = torch.stack(res, dim=-1).view(N, Cout, Hout, Wout)
    if bias is not None:
        res.add_(bias[None, :, None, None])
    return res
if __name__ == "__main__":
    torch.backends.cudnn.deterministic = True
    x = torch.randn(16, 128, 112, 112, device="cuda")
    w = torch.randn(256, 128, 3, 3, device="cuda")
    b = torch.randn(256, device="cuda")
    y1 = F.conv2d(x, w, b, (1, 1), (1, 1), (2, 2), 1)
    y2 = conv2d(x, w, b, (1, 1), (1, 1), (2, 2), 1)
    y3 = conv2d_2(x, w, b, (1, 1), (1, 1), (2, 2), 1)
    print(torch.allclose(y1, y2, atol=1e-3))
    print(torch.allclose(y2, y3, atol=1e-3))
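    # Depthwise (grouped) convolution test: groups == Cin == 128, so each input
    # channel is convolved with its own filters (weight shape [256, 1, 3, 3]).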
    x = torch.randn(16, 128, 112, 112, device="cuda")
    w = torch.randn(256, 1, 3, 3, device="cuda")
    b = torch.randn(256, device="cuda")
    y1 = F.conv2d(x, w, b, (1, 1), (1, 1), (2, 2), 128)
    y2 = conv2d(x, w, b, (1, 1), (1, 1), (2, 2), 128)
    y3 = conv2d_2(x, w, b, (1, 1), (1, 1), (2, 2), 128)
    print(torch.allclose(y1, y2, atol=1e-3))
    print(torch.allclose(y1, y3, atol=1e-3))
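The docstring above claims that conv2d (the unfold-based version) is faster than conv2d_2 at the cost of extra memory. A minimal timing sketch along the following lines can be used to check that claim; it is not part of the original code. It assumes a CUDA device is available and that the two functions above are in scope, and the bench helper and the small shapes are my own hypothetical choices so that the Python-loop version finishes quickly:

import time

def bench(fn, *args, n_warmup: int = 3, n_iter: int = 10) -> float:
    # Average seconds per call; synchronize because CUDA kernels run asynchronously.
    for _ in range(n_warmup):
        fn(*args)
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(n_iter):
        fn(*args)
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / n_iter

x = torch.randn(2, 16, 32, 32, device="cuda")
w = torch.randn(32, 16, 3, 3, device="cuda")
b = torch.randn(32, device="cuda")
for name, f in [("F.conv2d", F.conv2d), ("conv2d", conv2d), ("conv2d_2", conv2d_2)]:
    print(name, bench(f, x, w, b, (1, 1), (1, 1), (1, 1), 1))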