import math

import torch
from torch import nn
class CNN(nn.Module):
    """Single convolution + PReLU feature-extraction stage.

    The network consists of one ``nn.Conv2d`` with a 3x3 kernel followed by a
    channel-wise ``nn.PReLU``. Because the convolution uses ``padding=5 // 2``
    (i.e. 2) with a 3x3 kernel and stride 1, each spatial dimension of the
    output is 2 larger than the input: ``out = in - 3 + 2 * 2 + 1``.
    """

    def __init__(self, scale_factor, num_channels=1, d=56, s=12, m=4):
        """Build the layers and initialise their weights.

        Args:
            scale_factor: Accepted for interface compatibility; not used by
                the code in this class.
            num_channels: Number of input channels of the convolution.
            d: Number of output channels of the convolution, and the number
                of learnable PReLU slopes.
            s: Accepted for interface compatibility; not used here.
            m: Accepted for interface compatibility; not used here.
        """
        super().__init__()
        self.first_part = nn.Sequential(
            # Asymmetric alternatives that were tried instead of the 3x3 kernel:
            #   nn.Conv2d(num_channels, d, kernel_size=(1, 3), padding=3 // 2)
            #   nn.Conv2d(num_channels, d, kernel_size=(3, 1), padding=3 // 2)
            nn.Conv2d(num_channels, d, kernel_size=3, padding=5 // 2),
            nn.PReLU(d),
        )
        self._initialize_weights()

    def _initialize_weights(self):
        """Initialise every Conv2d in ``first_part``.

        Weights are drawn from N(0, std^2) with
        ``std = sqrt(2 / (out_channels * kernel_height * kernel_width))``;
        biases are set to zero.
        """
        for module in self.first_part:
            if isinstance(module, nn.Conv2d):
                # weight[0][0] is one (kernel_h, kernel_w) slice, so numel()
                # is the number of elements in a single 2-D kernel.
                kernel_elems = module.weight.data[0][0].numel()
                std = math.sqrt(2 / (module.out_channels * kernel_elems))
                nn.init.normal_(module.weight.data, mean=0.0, std=std)
                nn.init.zeros_(module.bias.data)

    def forward(self, x):
        """Apply the convolution and PReLU to ``x`` and return the result."""
        # The original computed this value but returned the untouched input.
        return self.first_part(x)
# Demo: push one random single-channel 28x36 input through the network and
# inspect the output tensor, one element of it, and its shape.
# scale_factor is a required constructor argument; omitting it raises TypeError.
model = CNN(scale_factor=3)
input = torch.randn(1, 1, 28, 36)
pre = model(input)
print(pre)
print(pre.data)
print(pre.data[0][0][1][1].item())
# The four expressions below are equivalent ways of reading the shape.
print(pre.data.size())
print(pre.data.shape)
print(pre.shape)
print(pre.size())
结果:
tensor([[[[ 6.6364e-02, -1.3303e-04, -3.5543e-02, ..., -6.1688e-02,
3.1265e-01, -8.2845e-03],
[-3.5500e-02, -4.9340e-02, 1.7163e-01, ..., 1.3813e-01,
-8.2757e-03, 8.5905e-03],
[ 8.1270e-02, -2.7835e-03, -2.5063e-02, ..., 1.3616e-01,
-6.6608e-02, -6.2187e-03],
...,
[ 4.4998e-02, -4.2893e-02, 2.1688e-01, ..., 3.1650e-01,
-2.5981e-02, 2.6824e-03],
[ 4.0486e-03, 1.2759e-02, 3.1908e-01, ..., 2.2530e-01,
-4.8846e-02, 3.9362e-02],
[ 8.3523e-03, 1.5983e-02, 6.9260e-02, ..., -2.0295e-02,
-2.0337e-02, 3.9436e-02]]]], grad_fn=)
tensor([[[[ 6.6364e-02, -1.3303e-04, -3.5543e-02, ..., -6.1688e-02,
3.1265e-01, -8.2845e-03],
[-3.5500e-02, -4.9340e-02, 1.7163e-01, ..., 1.3813e-01,
-8.2757e-03, 8.5905e-03],
[ 8.1270e-02, -2.7835e-03, -2.5063e-02, ..., 1.3616e-01,
-6.6608e-02, -6.2187e-03],
...,
[ 4.4998e-02, -4.2893e-02, 2.1688e-01, ..., 3.1650e-01,
-2.5981e-02, 2.6824e-03],
[ 4.0486e-03, 1.2759e-02, 3.1908e-01, ..., 2.2530e-01,
-4.8846e-02, 3.9362e-02],
[ 8.3523e-03, 1.5983e-02, 6.9260e-02, ..., -2.0295e-02,
-2.0337e-02, 3.9436e-02]]]])
-0.04934029281139374
torch.Size([1, 56, 30, 38]) #计算过程30 = (28-3+2x2)/1 + 1, 38 = (36-3+2x2)/1+1
torch.Size([1, 56, 30, 38])
torch.Size([1, 56, 30, 38])
torch.Size([1, 56, 30, 38])
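计算公式 (Conv2d output size, dilation = 1): H_out = floor((H_in + 2*padding_h - kernel_h) / stride) + 1, W_out = floor((W_in + 2*padding_w - kernel_w) / stride) + 1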
# nn.Conv2d(num_channels, d, kernel_size=(1,3), padding=3//2)方法
# Repeat the experiment and dump individual rows of the first output channel,
# followed by the output shape read four equivalent ways.
model = CNN(scale_factor=3)
input = torch.randn(1, 1, 28, 36)
pre = model(input)
print(pre)
# print(pre.data)
first_channel = pre.data[0][0]
print(first_channel)
print(first_channel[1][1].item())
# Rows 0..6 plus the last row of the first channel.
for row in (0, 1, 2, 3, 4, 5, 6, -1):
    print(first_channel[row])
for dims in (pre.data.size(), pre.data.shape, pre.shape, pre.size()):
    print(dims)
结果:
tensor([[[[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[-0.0315, 0.2222, -0.2953, ..., -0.7758, -0.6119, -0.2884],
[ 0.0894, 0.5969, 0.0282, ..., -0.2781, -0.3415, -0.1789],
...,
[-0.1551, 0.2289, 0.3626, ..., -0.1653, 0.3931, 0.3321],
[-0.4042, -0.1149, 0.0374, ..., 0.6755, 0.1891, -0.0073],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[-0.0212, 0.0645, -0.0238, ..., 0.1489, 0.1142, -0.1058],
[-0.0321, -0.0747, -0.0549, ..., 0.1178, 0.0921, -0.0454],
...,
[ 0.0469, -0.0927, -0.1170, ..., 0.2002, -0.2892, -0.0782],
[ 0.1105, -0.0514, -0.0808, ..., -0.1853, 0.1171, 0.1068],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[ 0.0835, -0.1424, 0.1503, ..., 0.1009, -0.0085, 0.1033],
[ 0.0232, -0.1340, 0.1330, ..., 0.0130, -0.0159, 0.0469],
...,
[-0.0155, -0.0275, 0.0204, ..., -0.1981, 0.0670, 0.0502],
[-0.0070, 0.0174, 0.0540, ..., 0.0298, -0.0567, -0.0900],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
...,
[[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[ 0.1037, 0.0157, 0.0510, ..., 0.0336, -0.1710, -0.2514],
[ 0.0334, -0.0469, 0.1680, ..., 0.1241, -0.0480, -0.1418],
...,
[-0.0269, -0.1274, 0.0102, ..., -0.0414, -0.3004, 0.1501],
[-0.0279, -0.1744, -0.0603, ..., 0.0836, 0.3084, 0.0699],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[-0.0355, 0.1285, -0.1310, ..., -0.1803, -0.1289, -0.1647],
[ 0.0068, 0.1698, -0.0434, ..., -0.0262, -0.0548, -0.0884],
...,
[-0.0207, 0.0267, 0.0473, ..., 0.0895, -0.0434, 0.0534],
[-0.0661, -0.0665, -0.0409, ..., 0.1070, 0.1317, 0.0704],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]],
[[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[ 0.2388, -0.1572, 0.2690, ..., 0.2536, -0.1721, -0.2086],
[ 0.0604, -0.2845, 0.3744, ..., 0.2247, -0.0409, -0.1276],
...,
[-0.0345, -0.2383, -0.0104, ..., -0.2471, -0.4034, 0.2226],
[ 0.0049, -0.2131, -0.0328, ..., 0.0631, 0.3511, 0.0043],
[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]]])
tensor([[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000],
[-0.0315, 0.2222, -0.2953, ..., -0.7758, -0.6119, -0.2884],
[ 0.0894, 0.5969, 0.0282, ..., -0.2781, -0.3415, -0.1789],
...,
[-0.1551, 0.2289, 0.3626, ..., -0.1653, 0.3931, 0.3321],
[-0.4042, -0.1149, 0.0374, ..., 0.6755, 0.1891, -0.0073],
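[ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]) 上下的0.0000是padding填充的结果 (the kernel height is 1 while the padding along the height is 1, so the first and last output rows cover only zero padding; with the bias initialised to zero they are exactly 0)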
0.22215284407138824
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
tensor([-0.0315, 0.2222, -0.2953, 0.1427, -0.2313, -0.5308, -0.2953, -0.2252,
0.2495, -0.0302, 0.1041, -0.5481, -0.0725, -0.0198, 0.2743, 0.4091,
-0.1611, -0.2999, -0.2310, 0.0428, -0.2597, 0.0626, -0.0776, 0.1501,
0.0440, -0.3942, -0.0305, 0.0865, 0.1754, -0.0359, 0.0967, -0.0470,
-0.2386, -0.7758, -0.6119, -0.2884])
tensor([ 0.0894, 0.5969, 0.0282, -0.0242, -0.1550, 0.1429, 0.0045, -0.4359,
-0.0452, 0.1667, 0.4410, 0.2647, 0.2186, -0.0451, 0.5798, 0.4195,
-0.2234, 0.0881, -0.0903, 0.0221, 0.2936, 0.0059, 0.2871, -0.2773,
0.1871, 0.0105, 0.1898, -0.0319, 0.1036, 0.1972, 0.1958, 0.2907,
0.1177, -0.2781, -0.3415, -0.1789])
tensor([ 0.5702, -0.0570, -0.5150, -0.3509, 0.2484, 0.5581, -0.0588, -0.0200,
-0.0346, 0.2426, -0.2978, 0.0562, -0.1117, -0.3938, -0.2345, -0.2661,
0.2033, -0.2472, 0.2124, -0.1569, -0.4880, -0.8923, -0.2910, 0.1807,
0.1235, 0.4139, 0.0701, -0.1217, 0.3566, 0.5507, 0.3293, 0.2777,
0.2855, 0.1828, 0.2574, 0.1168])
tensor([ 0.1693, 0.1578, -0.1188, -0.1728, -0.3290, 0.2469, 0.5145, 0.6355,
0.2808, -0.6514, -0.5881, 0.0906, -0.3123, 0.5027, 0.2909, 0.6274,
1.2742, 1.1023, 0.3220, 0.2541, -0.3775, -0.1256, -0.3329, 0.0644,
0.1964, 0.2278, 0.2365, 0.5741, 0.3289, 0.6638, 0.1729, -0.2596,
-0.1766, 0.1998, 0.4604, 0.2442])
tensor([ 0.0660, -0.3376, -0.1423, -0.2701, -0.0633, 0.1919, 0.4765, 0.2492,
0.0457, 0.0720, 0.3947, 0.1662, -0.0149, -0.3476, -0.4126, -0.5365,
-0.3285, 0.2710, 0.6813, 0.8570, -0.0067, -0.4599, 0.6223, 0.6125,
0.1302, -0.3530, -0.6021, 0.0301, -0.1814, 0.1297, -0.2597, -0.2491,
-0.5377, -0.2454, -0.4172, -0.1774])
tensor([-0.0815, -0.2279, -0.1583, -0.3918, -0.3411, -0.5401, -0.1437, 0.1512,
0.1818, 0.3640, 0.1382, -0.1335, -0.4312, -0.6804, -0.3305, 0.0529,
0.0879, 0.4044, 0.3641, 0.3695, 0.2838, 0.0965, 0.0414, 0.3015,
0.0043, 0.5818, 0.3425, -0.1442, -0.0725, -0.0736, 0.1090, -0.2067,
-0.0579, 0.0168, 0.3162, 0.1692])
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
torch.Size([1, 56, 30, 36])
torch.Size([1, 56, 30, 36])
torch.Size([1, 56, 30, 36])
torch.Size([1, 56, 30, 36])
# nn.Conv2d(num_channels, d, kernel_size=(3,1), padding=3//2)方法
结果:
tensor([[[[ 0.0000, -0.2665, 0.0052, ..., -0.0663, 0.2429, 0.0000],
[ 0.0000, -0.1602, -0.1386, ..., -0.1798, 0.4090, 0.0000],
[ 0.0000, 0.2891, -0.2059, ..., -0.2856, 0.1771, 0.0000],
...,
[ 0.0000, -0.2019, -0.1313, ..., 0.1987, -0.0560, 0.0000],
[ 0.0000, 0.0348, -0.0470, ..., 0.1345, -0.0694, 0.0000],
[ 0.0000, 0.0484, -0.0343, ..., 0.1832, 0.1907, 0.0000]],
[[ 0.0000, 0.2650, -0.0029, ..., 0.0679, -0.2456, 0.0000],
[ 0.0000, -0.0842, 0.1587, ..., 0.1347, -0.2150, 0.0000],
[ 0.0000, -0.2408, 0.0732, ..., 0.1868, 0.0150, 0.0000],
...,
[ 0.0000, 0.0350, 0.0971, ..., -0.0868, 0.1741, 0.0000],
[ 0.0000, -0.0737, -0.0437, ..., -0.0650, -0.0956, 0.0000],
[ 0.0000, 0.0191, 0.0822, ..., -0.1411, -0.1188, 0.0000]],
[[ 0.0000, 0.0656, -0.0032, ..., 0.0147, -0.0564, 0.0000],
[ 0.0000, -0.2945, 0.0594, ..., -0.0252, 0.1706, 0.0000],
[ 0.0000, -0.0024, -0.1312, ..., -0.0675, 0.2218, 0.0000],
...,
[ 0.0000, -0.1813, -0.0136, ..., 0.1093, 0.1821, 0.0000],
[ 0.0000, -0.0638, -0.1183, ..., 0.0690, -0.2098, 0.0000],
[ 0.0000, 0.0825, 0.0755, ..., 0.0136, 0.0535, 0.0000]],
...,
[[ 0.0000, -0.1202, 0.0746, ..., 0.0314, -0.0210, 0.0000],
[ 0.0000, -0.0126, -0.0183, ..., 0.0611, 0.0161, 0.0000],
[ 0.0000, 0.0497, -0.0397, ..., -0.0380, -0.0153, 0.0000],
...,
[ 0.0000, 0.0016, -0.0508, ..., -0.0282, -0.1388, 0.0000],
[ 0.0000, 0.0477, 0.0852, ..., -0.0699, 0.0255, 0.0000],
[ 0.0000, -0.0278, -0.0439, ..., 0.0365, 0.0193, 0.0000]],
[[ 0.0000, -0.1411, 0.0150, ..., -0.0247, 0.1065, 0.0000],
[ 0.0000, -0.3459, -0.0443, ..., -0.1247, 0.4057, 0.0000],
[ 0.0000, 0.1938, -0.2472, ..., -0.2463, 0.2921, 0.0000],
...,
[ 0.0000, -0.2777, -0.1054, ..., 0.2118, 0.0860, 0.0000],
[ 0.0000, -0.0205, -0.1105, ..., 0.1296, -0.2110, 0.0000],
[ 0.0000, 0.0932, 0.0311, ..., 0.1361, 0.1702, 0.0000]],
[[ 0.0000, -0.1253, 0.2331, ..., 0.1646, -0.3026, 0.0000],
[ 0.0000, -0.5238, 0.1267, ..., 0.2429, 0.1681, 0.0000],
[ 0.0000, 0.0067, -0.2741, ..., -0.1049, 0.2849, 0.0000],
...,
[ 0.0000, -0.2385, -0.1226, ..., 0.0156, -0.0712, 0.0000],
[ 0.0000, 0.0142, 0.0732, ..., -0.1638, -0.2847, 0.0000],
[ 0.0000, 0.0433, 0.0207, ..., 0.0494, 0.0665, 0.0000]]]])
tensor([[ 0.0000, -0.2665, 0.0052, ..., -0.0663, 0.2429, 0.0000],
[ 0.0000, -0.1602, -0.1386, ..., -0.1798, 0.4090, 0.0000],
[ 0.0000, 0.2891, -0.2059, ..., -0.2856, 0.1771, 0.0000],
...,
[ 0.0000, -0.2019, -0.1313, ..., 0.1987, -0.0560, 0.0000],
[ 0.0000, 0.0348, -0.0470, ..., 0.1345, -0.0694, 0.0000],
[ 0.0000, 0.0484, -0.0343, ..., 0.1832, 0.1907, 0.0000]])
-0.16020259261131287
tensor([ 0.0000, -0.2665, 0.0052, -0.0403, 0.0104, -0.1253, -0.0612, 0.0583,
-0.0431, 0.0037, -0.1454, 0.0387, -0.0086, 0.0119, 0.0965, 0.1796,
-0.0665, -0.1207, -0.1479, -0.0218, 0.1017, 0.3104, -0.1378, -0.0020,
-0.1307, -0.2126, -0.0714, 0.1420, -0.0902, 0.2329, 0.0456, 0.0767,
0.0070, 0.0960, 0.0613, -0.0663, 0.2429, 0.0000])
tensor([ 0.0000e+00, -1.6020e-01, -1.3857e-01, 4.4766e-02, 3.6776e-04,
-4.8631e-02, -1.4214e-01, -1.0698e-01, -1.0469e-01, -6.2759e-02,
-2.5928e-01, 1.4373e-02, 2.5133e-02, 5.7116e-03, 1.3167e-01,
3.3955e-01, 1.2455e-01, 6.8688e-03, -3.4117e-01, 2.0203e-02,
-1.1731e-01, 3.3134e-01, -1.1501e-01, -1.5608e-02, -1.6398e-01,
-2.8882e-01, -1.7618e-02, 2.4095e-01, -2.3142e-01, 2.1594e-01,
-1.6275e-02, 1.9483e-01, 8.6949e-02, 1.3403e-01, 1.2124e-01,
-1.7975e-01, 4.0899e-01, 0.0000e+00])
tensor([ 0.0000e+00, 2.8914e-01, -2.0587e-01, 2.5269e-01, 1.2159e-01,
5.8757e-02, -2.5931e-01, -2.5056e-01, -1.2442e-01, -1.2865e-01,
-3.4671e-05, -1.5588e-01, 2.7979e-02, 7.4262e-02, 3.3392e-02,
1.9631e-01, 2.6568e-01, 4.8094e-02, -2.2988e-01, -7.3294e-03,
-2.6656e-01, 5.9278e-02, -3.6376e-02, -6.8909e-02, -7.3849e-02,
3.8072e-02, 1.5668e-01, 3.2020e-01, -4.0010e-01, -6.2959e-02,
4.6597e-02, 2.0234e-01, 1.4463e-01, 8.6147e-03, -1.4288e-03,
-2.8556e-01, 1.7709e-01, 0.0000e+00])
tensor([ 0.0000, 0.2429, 0.0210, 0.3643, 0.2347, -0.0319, -0.2282, -0.2728,
0.2310, -0.3252, 0.1110, 0.0297, -0.0811, 0.0235, -0.0734, 0.0109,
0.0639, -0.0535, 0.1225, 0.0271, 0.0277, 0.1528, -0.1052, -0.1419,
0.0670, 0.0415, 0.1639, 0.3170, -0.1593, 0.0248, 0.3244, -0.0540,
0.0971, 0.1045, 0.1890, -0.1563, -0.1246, 0.0000])
tensor([ 0.0000, -0.0021, -0.0117, -0.0706, 0.3629, 0.0901, -0.0349, -0.1799,
0.2409, -0.2738, 0.0473, 0.2873, -0.2413, 0.0266, -0.2500, 0.0446,
0.0846, 0.1752, 0.1120, 0.1048, 0.2087, 0.0501, -0.1243, -0.1623,
0.2456, -0.2475, 0.1696, 0.0806, 0.1081, 0.1978, 0.0715, -0.0433,
0.0588, 0.2438, 0.2774, 0.0420, -0.2895, 0.0000])
tensor([ 0.0000, 0.0678, -0.0804, -0.1645, 0.0288, 0.1927, -0.0944, 0.1028,
0.0415, 0.1267, -0.0689, 0.1832, -0.0494, 0.0267, -0.2153, 0.1689,
0.1709, 0.1118, -0.0311, 0.2397, -0.0135, -0.1035, 0.1029, 0.0398,
0.1069, -0.1409, 0.1346, -0.2777, 0.0086, 0.0522, -0.1968, 0.2060,
-0.0040, -0.0089, -0.0507, 0.0334, -0.1372, 0.0000])
tensor([ 0.0000e+00, 1.8976e-02, -1.9297e-02, 3.5077e-01, -8.7172e-02,
1.7590e-01, -1.0609e-01, 3.2915e-01, -3.1518e-03, 1.7384e-01,
-5.0477e-02, -3.3394e-02, -3.5567e-04, -9.7546e-02, -1.5968e-01,
1.0812e-01, 1.9788e-02, -1.7623e-02, -3.1855e-02, 3.7437e-01,
-2.5711e-01, 2.5959e-02, 2.4762e-01, 1.3323e-01, -4.1156e-02,
-3.4118e-02, 1.7522e-03, -1.9903e-01, -6.9717e-02, -4.9169e-02,
-1.8395e-01, 4.8062e-02, 1.3727e-01, -6.3850e-02, -1.5100e-01,
-2.7803e-01, 2.4404e-01, 0.0000e+00])
tensor([ 0.0000, 0.0484, -0.0343, -0.0134, -0.0573, 0.1453, -0.1117, 0.1103,
-0.3307, 0.1104, 0.2315, -0.0822, -0.0464, 0.1169, -0.0332, -0.3909,
0.1107, -0.0350, -0.1648, -0.0221, 0.2701, 0.0953, 0.0496, 0.1484,
0.0852, 0.1883, 0.1308, -0.2373, 0.0845, -0.0580, 0.0799, -0.1907,
0.0123, -0.0902, -0.1832, 0.1832, 0.1907, 0.0000])
torch.Size([1, 56, 28, 38])
torch.Size([1, 56, 28, 38])
torch.Size([1, 56, 28, 38])
torch.Size([1, 56, 28, 38])