pytorch 卷积计算

import math

import torch
from torch import nn

class CNN(nn.Module):
    """Minimal convolution demo: one ``Conv2d`` followed by a per-channel ``PReLU``.

    With stride 1 the output spatial size is ``H + 2 * pad_h - k_h + 1`` by
    ``W + 2 * pad_w - k_w + 1``; e.g. the defaults (kernel 3, padding 2) map a
    28x36 input to 30x38.

    Args:
        scale_factor: Accepted for interface compatibility; not used by this
            layer stack. Defaults to ``None`` so ``CNN()`` can be constructed.
        num_channels: Number of input channels.
        d: Number of output feature maps (conv output channels, and the number
            of learnable PReLU slopes).
        s: Accepted but unused here.
        m: Accepted but unused here.
        kernel_size: Convolution kernel size, an int or an ``(h, w)`` tuple,
            e.g. ``(1, 3)`` or ``(3, 1)`` for the asymmetric-kernel experiments.
        padding: Zero padding, an int or an ``(h, w)`` tuple. The default is
            ``5 // 2`` (= 2), as in the original snippet; the ``(1, 3)`` and
            ``(3, 1)`` experiments use ``3 // 2`` (= 1).
    """

    def __init__(self, scale_factor=None, num_channels=1, d=56, s=12, m=4,
                 kernel_size=3, padding=5 // 2):
        super().__init__()
        self.first_part = nn.Sequential(
            nn.Conv2d(num_channels, d, kernel_size=kernel_size, padding=padding),
            nn.PReLU(d),
        )
        self._initialize_weights()

    def _initialize_weights(self):
        """Re-initialise every ``Conv2d`` in ``first_part``.

        Weights are drawn from N(0, std) with
        ``std = sqrt(2 / (out_channels * kernel_h * kernel_w))``; biases are zeroed.
        """
        for module in self.first_part:
            if isinstance(module, nn.Conv2d):
                # Same fan as weight[0][0].numel(): the area of one kernel slice.
                kernel_h, kernel_w = module.kernel_size
                fan = module.out_channels * kernel_h * kernel_w
                # nn.init functions already run under torch.no_grad(), so the
                # Parameter can be passed directly instead of via .data.
                nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2 / fan))
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x):
        """Apply the conv + PReLU stage.

        Args:
            x: Input batch, ``(N, num_channels, H, W)`` as required by ``Conv2d``.

        Returns:
            Feature maps of shape ``(N, d, H_out, W_out)``.
        """
        return self.first_part(x)

# Demo 1: first_part uses kernel_size=3, padding=5 // 2 (= 2), so the
# convolution maps 28x36 to 30x38: 30 = (28 - 3 + 2*2)/1 + 1, 38 = (36 - 3 + 2*2)/1 + 1.
model = CNN(scale_factor=3)  # pass scale_factor explicitly; it has no default in the original CNN
x = torch.randn(1, 1, 28, 36)  # named x to avoid shadowing the built-in input()
pre = model(x)
print(pre)
# .detach() is the recommended replacement for .data: same values, no autograd history.
print(pre.detach())
print(pre[0, 0, 1, 1].item())
# size() and .shape are interchangeable ways of reading the same torch.Size.
print(pre.detach().size())
print(pre.detach().shape)
print(pre.shape)
print(pre.size())

结果:
tensor([[[[ 6.6364e-02, -1.3303e-04, -3.5543e-02,  ..., -6.1688e-02,
            3.1265e-01, -8.2845e-03],
          [-3.5500e-02, -4.9340e-02,  1.7163e-01,  ...,  1.3813e-01,
           -8.2757e-03,  8.5905e-03],
          [ 8.1270e-02, -2.7835e-03, -2.5063e-02,  ...,  1.3616e-01,
           -6.6608e-02, -6.2187e-03],
          ...,
          [ 4.4998e-02, -4.2893e-02,  2.1688e-01,  ...,  3.1650e-01,
           -2.5981e-02,  2.6824e-03],
          [ 4.0486e-03,  1.2759e-02,  3.1908e-01,  ...,  2.2530e-01,
           -4.8846e-02,  3.9362e-02],
          [ 8.3523e-03,  1.5983e-02,  6.9260e-02,  ..., -2.0295e-02,
           -2.0337e-02,  3.9436e-02]]]], grad_fn=<PreluBackward>)
tensor([[[[ 6.6364e-02, -1.3303e-04, -3.5543e-02,  ..., -6.1688e-02,
            3.1265e-01, -8.2845e-03],
          [-3.5500e-02, -4.9340e-02,  1.7163e-01,  ...,  1.3813e-01,
           -8.2757e-03,  8.5905e-03],
          [ 8.1270e-02, -2.7835e-03, -2.5063e-02,  ...,  1.3616e-01,
           -6.6608e-02, -6.2187e-03],
          ...,
          [ 4.4998e-02, -4.2893e-02,  2.1688e-01,  ...,  3.1650e-01,
           -2.5981e-02,  2.6824e-03],
          [ 4.0486e-03,  1.2759e-02,  3.1908e-01,  ...,  2.2530e-01,
           -4.8846e-02,  3.9362e-02],
          [ 8.3523e-03,  1.5983e-02,  6.9260e-02,  ..., -2.0295e-02,
           -2.0337e-02,  3.9436e-02]]]])
-0.04934029281139374
torch.Size([1, 56, 30, 38]) #计算过程30 = (28-3+2x2)/1 + 1, 38 = (36-3+2x2)/1+1
torch.Size([1, 56, 30, 38])
torch.Size([1, 56, 30, 38])
torch.Size([1, 56, 30, 38])

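计算公式(dilation=1;k 为卷积核大小,p 为 padding,s 为 stride,高、宽方向分别使用各自的 k 与 p):

$$H_{out} = \left\lfloor \frac{H_{in} + 2p - k}{s} \right\rfloor + 1,\qquad W_{out} = \left\lfloor \frac{W_{in} + 2p - k}{s} \right\rfloor + 1$$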

# Demo 2: run with the convolution in first_part swapped for
#     nn.Conv2d(num_channels, d, kernel_size=(1, 3), padding=3 // 2)
# The output is then (1, 56, 30, 36). The first and last rows of every channel
# are 0: they are computed purely from zero padding, the conv bias is
# initialised to 0, and PReLU(0) == 0.
model = CNN(scale_factor=3)
x = torch.randn(1, 1, 28, 36)  # named x to avoid shadowing the built-in input()
pre = model(x)
print(pre)
channel0 = pre.detach()[0, 0]  # first sample, first output channel
print(channel0)
print(channel0[1, 1].item())
# Rows 0..6 one by one, then the last row.
for row in range(7):
    print(channel0[row])
print(channel0[-1])
print(pre.detach().size())
print(pre.detach().shape)
print(pre.shape)
print(pre.size())

结果:
tensor([[[[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [-0.0315,  0.2222, -0.2953,  ..., -0.7758, -0.6119, -0.2884],
          [ 0.0894,  0.5969,  0.0282,  ..., -0.2781, -0.3415, -0.1789],
          ...,
          [-0.1551,  0.2289,  0.3626,  ..., -0.1653,  0.3931,  0.3321],
          [-0.4042, -0.1149,  0.0374,  ...,  0.6755,  0.1891, -0.0073],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [-0.0212,  0.0645, -0.0238,  ...,  0.1489,  0.1142, -0.1058],
          [-0.0321, -0.0747, -0.0549,  ...,  0.1178,  0.0921, -0.0454],
          ...,
          [ 0.0469, -0.0927, -0.1170,  ...,  0.2002, -0.2892, -0.0782],
          [ 0.1105, -0.0514, -0.0808,  ..., -0.1853,  0.1171,  0.1068],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.0835, -0.1424,  0.1503,  ...,  0.1009, -0.0085,  0.1033],
          [ 0.0232, -0.1340,  0.1330,  ...,  0.0130, -0.0159,  0.0469],
          ...,
          [-0.0155, -0.0275,  0.0204,  ..., -0.1981,  0.0670,  0.0502],
          [-0.0070,  0.0174,  0.0540,  ...,  0.0298, -0.0567, -0.0900],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

         ...,

         [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.1037,  0.0157,  0.0510,  ...,  0.0336, -0.1710, -0.2514],
          [ 0.0334, -0.0469,  0.1680,  ...,  0.1241, -0.0480, -0.1418],
          ...,
          [-0.0269, -0.1274,  0.0102,  ..., -0.0414, -0.3004,  0.1501],
          [-0.0279, -0.1744, -0.0603,  ...,  0.0836,  0.3084,  0.0699],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [-0.0355,  0.1285, -0.1310,  ..., -0.1803, -0.1289, -0.1647],
          [ 0.0068,  0.1698, -0.0434,  ..., -0.0262, -0.0548, -0.0884],
          ...,
          [-0.0207,  0.0267,  0.0473,  ...,  0.0895, -0.0434,  0.0534],
          [-0.0661, -0.0665, -0.0409,  ...,  0.1070,  0.1317,  0.0704],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

         [[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
          [ 0.2388, -0.1572,  0.2690,  ...,  0.2536, -0.1721, -0.2086],
          [ 0.0604, -0.2845,  0.3744,  ...,  0.2247, -0.0409, -0.1276],
          ...,
          [-0.0345, -0.2383, -0.0104,  ..., -0.2471, -0.4034,  0.2226],
          [ 0.0049, -0.2131, -0.0328,  ...,  0.0631,  0.3511,  0.0043],
          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]]])
tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0315,  0.2222, -0.2953,  ..., -0.7758, -0.6119, -0.2884],
        [ 0.0894,  0.5969,  0.0282,  ..., -0.2781, -0.3415, -0.1789],
        ...,
        [-0.1551,  0.2289,  0.3626,  ..., -0.1653,  0.3931,  0.3321],
        [-0.4042, -0.1149,  0.0374,  ...,  0.6755,  0.1891, -0.0073],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]) 上下两行的0.0000来自padding填充:卷积核高度为1,首末两行输出只覆盖高度方向上的零填充行,而bias初始化为0且PReLU(0)=0
0.22215284407138824
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
tensor([-0.0315,  0.2222, -0.2953,  0.1427, -0.2313, -0.5308, -0.2953, -0.2252,
         0.2495, -0.0302,  0.1041, -0.5481, -0.0725, -0.0198,  0.2743,  0.4091,
        -0.1611, -0.2999, -0.2310,  0.0428, -0.2597,  0.0626, -0.0776,  0.1501,
         0.0440, -0.3942, -0.0305,  0.0865,  0.1754, -0.0359,  0.0967, -0.0470,
        -0.2386, -0.7758, -0.6119, -0.2884])
tensor([ 0.0894,  0.5969,  0.0282, -0.0242, -0.1550,  0.1429,  0.0045, -0.4359,
        -0.0452,  0.1667,  0.4410,  0.2647,  0.2186, -0.0451,  0.5798,  0.4195,
        -0.2234,  0.0881, -0.0903,  0.0221,  0.2936,  0.0059,  0.2871, -0.2773,
         0.1871,  0.0105,  0.1898, -0.0319,  0.1036,  0.1972,  0.1958,  0.2907,
         0.1177, -0.2781, -0.3415, -0.1789])
tensor([ 0.5702, -0.0570, -0.5150, -0.3509,  0.2484,  0.5581, -0.0588, -0.0200,
        -0.0346,  0.2426, -0.2978,  0.0562, -0.1117, -0.3938, -0.2345, -0.2661,
         0.2033, -0.2472,  0.2124, -0.1569, -0.4880, -0.8923, -0.2910,  0.1807,
         0.1235,  0.4139,  0.0701, -0.1217,  0.3566,  0.5507,  0.3293,  0.2777,
         0.2855,  0.1828,  0.2574,  0.1168])
tensor([ 0.1693,  0.1578, -0.1188, -0.1728, -0.3290,  0.2469,  0.5145,  0.6355,
         0.2808, -0.6514, -0.5881,  0.0906, -0.3123,  0.5027,  0.2909,  0.6274,
         1.2742,  1.1023,  0.3220,  0.2541, -0.3775, -0.1256, -0.3329,  0.0644,
         0.1964,  0.2278,  0.2365,  0.5741,  0.3289,  0.6638,  0.1729, -0.2596,
        -0.1766,  0.1998,  0.4604,  0.2442])
tensor([ 0.0660, -0.3376, -0.1423, -0.2701, -0.0633,  0.1919,  0.4765,  0.2492,
         0.0457,  0.0720,  0.3947,  0.1662, -0.0149, -0.3476, -0.4126, -0.5365,
        -0.3285,  0.2710,  0.6813,  0.8570, -0.0067, -0.4599,  0.6223,  0.6125,
         0.1302, -0.3530, -0.6021,  0.0301, -0.1814,  0.1297, -0.2597, -0.2491,
        -0.5377, -0.2454, -0.4172, -0.1774])
tensor([-0.0815, -0.2279, -0.1583, -0.3918, -0.3411, -0.5401, -0.1437,  0.1512,
         0.1818,  0.3640,  0.1382, -0.1335, -0.4312, -0.6804, -0.3305,  0.0529,
         0.0879,  0.4044,  0.3641,  0.3695,  0.2838,  0.0965,  0.0414,  0.3015,
         0.0043,  0.5818,  0.3425, -0.1442, -0.0725, -0.0736,  0.1090, -0.2067,
        -0.0579,  0.0168,  0.3162,  0.1692])
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
torch.Size([1, 56, 30, 36])
torch.Size([1, 56, 30, 36])
torch.Size([1, 56, 30, 36])
torch.Size([1, 56, 30, 36])


# nn.Conv2d(num_channels, d, kernel_size=(3,1), padding=3//2)方法
结果:
tensor([[[[ 0.0000, -0.2665,  0.0052,  ..., -0.0663,  0.2429,  0.0000],
          [ 0.0000, -0.1602, -0.1386,  ..., -0.1798,  0.4090,  0.0000],
          [ 0.0000,  0.2891, -0.2059,  ..., -0.2856,  0.1771,  0.0000],
          ...,
          [ 0.0000, -0.2019, -0.1313,  ...,  0.1987, -0.0560,  0.0000],
          [ 0.0000,  0.0348, -0.0470,  ...,  0.1345, -0.0694,  0.0000],
          [ 0.0000,  0.0484, -0.0343,  ...,  0.1832,  0.1907,  0.0000]],

         [[ 0.0000,  0.2650, -0.0029,  ...,  0.0679, -0.2456,  0.0000],
          [ 0.0000, -0.0842,  0.1587,  ...,  0.1347, -0.2150,  0.0000],
          [ 0.0000, -0.2408,  0.0732,  ...,  0.1868,  0.0150,  0.0000],
          ...,
          [ 0.0000,  0.0350,  0.0971,  ..., -0.0868,  0.1741,  0.0000],
          [ 0.0000, -0.0737, -0.0437,  ..., -0.0650, -0.0956,  0.0000],
          [ 0.0000,  0.0191,  0.0822,  ..., -0.1411, -0.1188,  0.0000]],

         [[ 0.0000,  0.0656, -0.0032,  ...,  0.0147, -0.0564,  0.0000],
          [ 0.0000, -0.2945,  0.0594,  ..., -0.0252,  0.1706,  0.0000],
          [ 0.0000, -0.0024, -0.1312,  ..., -0.0675,  0.2218,  0.0000],
          ...,
          [ 0.0000, -0.1813, -0.0136,  ...,  0.1093,  0.1821,  0.0000],
          [ 0.0000, -0.0638, -0.1183,  ...,  0.0690, -0.2098,  0.0000],
          [ 0.0000,  0.0825,  0.0755,  ...,  0.0136,  0.0535,  0.0000]],

         ...,

         [[ 0.0000, -0.1202,  0.0746,  ...,  0.0314, -0.0210,  0.0000],
          [ 0.0000, -0.0126, -0.0183,  ...,  0.0611,  0.0161,  0.0000],
          [ 0.0000,  0.0497, -0.0397,  ..., -0.0380, -0.0153,  0.0000],
          ...,
          [ 0.0000,  0.0016, -0.0508,  ..., -0.0282, -0.1388,  0.0000],
          [ 0.0000,  0.0477,  0.0852,  ..., -0.0699,  0.0255,  0.0000],
          [ 0.0000, -0.0278, -0.0439,  ...,  0.0365,  0.0193,  0.0000]],

         [[ 0.0000, -0.1411,  0.0150,  ..., -0.0247,  0.1065,  0.0000],
          [ 0.0000, -0.3459, -0.0443,  ..., -0.1247,  0.4057,  0.0000],
          [ 0.0000,  0.1938, -0.2472,  ..., -0.2463,  0.2921,  0.0000],
          ...,
          [ 0.0000, -0.2777, -0.1054,  ...,  0.2118,  0.0860,  0.0000],
          [ 0.0000, -0.0205, -0.1105,  ...,  0.1296, -0.2110,  0.0000],
          [ 0.0000,  0.0932,  0.0311,  ...,  0.1361,  0.1702,  0.0000]],

         [[ 0.0000, -0.1253,  0.2331,  ...,  0.1646, -0.3026,  0.0000],
          [ 0.0000, -0.5238,  0.1267,  ...,  0.2429,  0.1681,  0.0000],
          [ 0.0000,  0.0067, -0.2741,  ..., -0.1049,  0.2849,  0.0000],
          ...,
          [ 0.0000, -0.2385, -0.1226,  ...,  0.0156, -0.0712,  0.0000],
          [ 0.0000,  0.0142,  0.0732,  ..., -0.1638, -0.2847,  0.0000],
          [ 0.0000,  0.0433,  0.0207,  ...,  0.0494,  0.0665,  0.0000]]]])
tensor([[ 0.0000, -0.2665,  0.0052,  ..., -0.0663,  0.2429,  0.0000],
        [ 0.0000, -0.1602, -0.1386,  ..., -0.1798,  0.4090,  0.0000],
        [ 0.0000,  0.2891, -0.2059,  ..., -0.2856,  0.1771,  0.0000],
        ...,
        [ 0.0000, -0.2019, -0.1313,  ...,  0.1987, -0.0560,  0.0000],
        [ 0.0000,  0.0348, -0.0470,  ...,  0.1345, -0.0694,  0.0000],
        [ 0.0000,  0.0484, -0.0343,  ...,  0.1832,  0.1907,  0.0000]])
-0.16020259261131287
tensor([ 0.0000, -0.2665,  0.0052, -0.0403,  0.0104, -0.1253, -0.0612,  0.0583,
        -0.0431,  0.0037, -0.1454,  0.0387, -0.0086,  0.0119,  0.0965,  0.1796,
        -0.0665, -0.1207, -0.1479, -0.0218,  0.1017,  0.3104, -0.1378, -0.0020,
        -0.1307, -0.2126, -0.0714,  0.1420, -0.0902,  0.2329,  0.0456,  0.0767,
         0.0070,  0.0960,  0.0613, -0.0663,  0.2429,  0.0000])
tensor([ 0.0000e+00, -1.6020e-01, -1.3857e-01,  4.4766e-02,  3.6776e-04,
        -4.8631e-02, -1.4214e-01, -1.0698e-01, -1.0469e-01, -6.2759e-02,
        -2.5928e-01,  1.4373e-02,  2.5133e-02,  5.7116e-03,  1.3167e-01,
         3.3955e-01,  1.2455e-01,  6.8688e-03, -3.4117e-01,  2.0203e-02,
        -1.1731e-01,  3.3134e-01, -1.1501e-01, -1.5608e-02, -1.6398e-01,
        -2.8882e-01, -1.7618e-02,  2.4095e-01, -2.3142e-01,  2.1594e-01,
        -1.6275e-02,  1.9483e-01,  8.6949e-02,  1.3403e-01,  1.2124e-01,
        -1.7975e-01,  4.0899e-01,  0.0000e+00])
tensor([ 0.0000e+00,  2.8914e-01, -2.0587e-01,  2.5269e-01,  1.2159e-01,
         5.8757e-02, -2.5931e-01, -2.5056e-01, -1.2442e-01, -1.2865e-01,
        -3.4671e-05, -1.5588e-01,  2.7979e-02,  7.4262e-02,  3.3392e-02,
         1.9631e-01,  2.6568e-01,  4.8094e-02, -2.2988e-01, -7.3294e-03,
        -2.6656e-01,  5.9278e-02, -3.6376e-02, -6.8909e-02, -7.3849e-02,
         3.8072e-02,  1.5668e-01,  3.2020e-01, -4.0010e-01, -6.2959e-02,
         4.6597e-02,  2.0234e-01,  1.4463e-01,  8.6147e-03, -1.4288e-03,
        -2.8556e-01,  1.7709e-01,  0.0000e+00])
tensor([ 0.0000,  0.2429,  0.0210,  0.3643,  0.2347, -0.0319, -0.2282, -0.2728,
         0.2310, -0.3252,  0.1110,  0.0297, -0.0811,  0.0235, -0.0734,  0.0109,
         0.0639, -0.0535,  0.1225,  0.0271,  0.0277,  0.1528, -0.1052, -0.1419,
         0.0670,  0.0415,  0.1639,  0.3170, -0.1593,  0.0248,  0.3244, -0.0540,
         0.0971,  0.1045,  0.1890, -0.1563, -0.1246,  0.0000])
tensor([ 0.0000, -0.0021, -0.0117, -0.0706,  0.3629,  0.0901, -0.0349, -0.1799,
         0.2409, -0.2738,  0.0473,  0.2873, -0.2413,  0.0266, -0.2500,  0.0446,
         0.0846,  0.1752,  0.1120,  0.1048,  0.2087,  0.0501, -0.1243, -0.1623,
         0.2456, -0.2475,  0.1696,  0.0806,  0.1081,  0.1978,  0.0715, -0.0433,
         0.0588,  0.2438,  0.2774,  0.0420, -0.2895,  0.0000])
tensor([ 0.0000,  0.0678, -0.0804, -0.1645,  0.0288,  0.1927, -0.0944,  0.1028,
         0.0415,  0.1267, -0.0689,  0.1832, -0.0494,  0.0267, -0.2153,  0.1689,
         0.1709,  0.1118, -0.0311,  0.2397, -0.0135, -0.1035,  0.1029,  0.0398,
         0.1069, -0.1409,  0.1346, -0.2777,  0.0086,  0.0522, -0.1968,  0.2060,
        -0.0040, -0.0089, -0.0507,  0.0334, -0.1372,  0.0000])
tensor([ 0.0000e+00,  1.8976e-02, -1.9297e-02,  3.5077e-01, -8.7172e-02,
         1.7590e-01, -1.0609e-01,  3.2915e-01, -3.1518e-03,  1.7384e-01,
        -5.0477e-02, -3.3394e-02, -3.5567e-04, -9.7546e-02, -1.5968e-01,
         1.0812e-01,  1.9788e-02, -1.7623e-02, -3.1855e-02,  3.7437e-01,
        -2.5711e-01,  2.5959e-02,  2.4762e-01,  1.3323e-01, -4.1156e-02,
        -3.4118e-02,  1.7522e-03, -1.9903e-01, -6.9717e-02, -4.9169e-02,
        -1.8395e-01,  4.8062e-02,  1.3727e-01, -6.3850e-02, -1.5100e-01,
        -2.7803e-01,  2.4404e-01,  0.0000e+00])
tensor([ 0.0000,  0.0484, -0.0343, -0.0134, -0.0573,  0.1453, -0.1117,  0.1103,
        -0.3307,  0.1104,  0.2315, -0.0822, -0.0464,  0.1169, -0.0332, -0.3909,
         0.1107, -0.0350, -0.1648, -0.0221,  0.2701,  0.0953,  0.0496,  0.1484,
         0.0852,  0.1883,  0.1308, -0.2373,  0.0845, -0.0580,  0.0799, -0.1907,
         0.0123, -0.0902, -0.1832,  0.1832,  0.1907,  0.0000])
torch.Size([1, 56, 28, 38])
torch.Size([1, 56, 28, 38])
torch.Size([1, 56, 28, 38])
torch.Size([1, 56, 28, 38])

 

你可能感兴趣的:(pytorch学习)