在 forward 内部并非必须使用 PyTorch 的内置函数,也可以写自己的操作(例如插值、像素重排等):
例如如下的操作(使用张量):
# Four independent random feature maps, each (batch=1, channels=4, 3, 3).
x1, x2, x3, x4 = (torch.rand(1, 4, 3, 3) for _ in range(4))
得到的值是:
tensor([[[[0.1380, 0.9205, 0.1776],
[0.7586, 0.9378, 0.0803],
[0.1316, 0.4958, 0.4530]],
[[0.9914, 0.2687, 0.5410],
[0.5149, 0.6382, 0.1535],
[0.6628, 0.2382, 0.4334]],
[[0.7072, 0.7634, 0.0528],
[0.4097, 0.5194, 0.2898],
[0.2384, 0.0681, 0.0125]],
[[0.0090, 0.9438, 0.3352],
[0.8833, 0.1081, 0.6163],
[0.5972, 0.6735, 0.7644]]]])
tensor([[[[0.1295, 0.1420, 0.0031],
[0.6151, 0.8103, 0.5446],
[0.9363, 0.7179, 0.6631]],
[[0.5388, 0.4465, 0.6568],
[0.1802, 0.8209, 0.1948],
[0.0764, 0.3371, 0.1915]],
[[0.6165, 0.3070, 0.1122],
[0.0084, 0.7888, 0.5461],
[0.0845, 0.0765, 0.7305]],
[[0.9088, 0.5142, 0.0295],
[0.6926, 0.0097, 0.8664],
[0.2876, 0.2839, 0.3872]]]])
tensor([[[[0.8752, 0.5030, 0.9646],
[0.9026, 0.9058, 0.6884],
[0.5935, 0.4829, 0.1357]],
[[0.7352, 0.2270, 0.1895],
[0.5924, 0.6434, 0.4961],
[0.5629, 0.5115, 0.3605]],
[[0.2796, 0.9790, 0.2157],
[0.7628, 0.7351, 0.5707],
[0.1586, 0.6864, 0.0120]],
[[0.2966, 0.9805, 0.9305],
[0.1588, 0.3001, 0.4055],
[0.1049, 0.2991, 0.2499]]]])
tensor([[[[0.8597, 0.8146, 0.0946],
[0.0017, 0.4607, 0.4496],
[0.3836, 0.4743, 0.7664]],
[[0.5793, 0.2827, 0.3936],
[0.3975, 0.8144, 0.6542],
[0.7341, 0.9724, 0.4592]],
[[0.2093, 0.9717, 0.2315],
[0.4153, 0.3151, 0.7746],
[0.7166, 0.9342, 0.6998]],
[[0.4895, 0.7458, 0.7343],
[0.4585, 0.2103, 0.1174],
[0.6044, 0.9984, 0.1777]]]])
# Interleave the four 3x3 maps into one 6x6 map, pixel-shuffle style:
#   x1 -> even rows / even cols,  x2 -> odd rows / even cols,
#   x3 -> even rows / odd cols,   x4 -> odd rows / odd cols.
# Strided slice assignment replaces the original nested Python loops:
# it is vectorized, equivalent element-for-element, and easier to read.
x = torch.zeros((1, 4, 6, 6))
x[0, :, 0::2, 0::2] = x1[0]
x[0, :, 1::2, 0::2] = x2[0]
x[0, :, 0::2, 1::2] = x3[0]
x[0, :, 1::2, 1::2] = x4[0]
print(x)
最后得到的结果:
tensor([[[[0.1380, 0.8752, 0.9205, 0.5030, 0.1776, 0.9646],
[0.1295, 0.8597, 0.1420, 0.8146, 0.0031, 0.0946],
[0.7586, 0.9026, 0.9378, 0.9058, 0.0803, 0.6884],
[0.6151, 0.0017, 0.8103, 0.4607, 0.5446, 0.4496],
[0.1316, 0.5935, 0.4958, 0.4829, 0.4530, 0.1357],
[0.9363, 0.3836, 0.7179, 0.4743, 0.6631, 0.7664]],
[[0.9914, 0.7352, 0.2687, 0.2270, 0.5410, 0.1895],
[0.5388, 0.5793, 0.4465, 0.2827, 0.6568, 0.3936],
[0.5149, 0.5924, 0.6382, 0.6434, 0.1535, 0.4961],
[0.1802, 0.3975, 0.8209, 0.8144, 0.1948, 0.6542],
[0.6628, 0.5629, 0.2382, 0.5115, 0.4334, 0.3605],
[0.0764, 0.7341, 0.3371, 0.9724, 0.1915, 0.4592]],
[[0.7072, 0.2796, 0.7634, 0.9790, 0.0528, 0.2157],
[0.6165, 0.2093, 0.3070, 0.9717, 0.1122, 0.2315],
[0.4097, 0.7628, 0.5194, 0.7351, 0.2898, 0.5707],
[0.0084, 0.4153, 0.7888, 0.3151, 0.5461, 0.7746],
[0.2384, 0.1586, 0.0681, 0.6864, 0.0125, 0.0120],
[0.0845, 0.7166, 0.0765, 0.9342, 0.7305, 0.6998]],
[[0.0090, 0.2966, 0.9438, 0.9805, 0.3352, 0.9305],
[0.9088, 0.4895, 0.5142, 0.7458, 0.0295, 0.7343],
[0.8833, 0.1588, 0.1081, 0.3001, 0.6163, 0.4055],
[0.6926, 0.4585, 0.0097, 0.2103, 0.8664, 0.1174],
[0.5972, 0.1049, 0.6735, 0.2991, 0.7644, 0.2499],
[0.2876, 0.6044, 0.2839, 0.9984, 0.3872, 0.1777]]]])
如果使用神经网络:
class CNN(nn.Module):
    """Four parallel conv branches whose pooled outputs are interleaved
    into a single double-resolution feature map.

    Each branch is Conv2d(4->2, kernel k, padding (k-1)//2) -> BatchNorm
    -> SiLU -> 2x2 max-pool, for k in (3, 5, 7, 9). On a (B, 4, H, W)
    input each branch yields (B, 2, H//2, W//2); the four maps are then
    woven together pixel-wise into a (B, 2, H, W) output.
    """

    def __init__(self):
        super(CNN, self).__init__()
        # One branch per kernel size. Attribute names conv1..conv4 are
        # kept for backward compatibility with existing callers.
        self.conv1 = self._make_branch(3)
        self.conv2 = self._make_branch(5)
        self.conv3 = self._make_branch(7)
        self.conv4 = self._make_branch(9)

    @staticmethod
    def _make_branch(kernel_size):
        """Conv(4->2) -> BatchNorm -> SiLU -> 2x2 max-pool.

        padding = (kernel_size - 1) // 2 keeps the spatial size
        unchanged before the pool halves it (stride=1).
        """
        return nn.Sequential(
            nn.Conv2d(
                in_channels=4,
                out_channels=2,
                kernel_size=kernel_size,
                stride=1,
                padding=(kernel_size - 1) // 2,
            ),
            nn.BatchNorm2d(2, eps=0.001, momentum=0.03, affine=True,
                           track_running_stats=True),
            nn.SiLU(),
            nn.MaxPool2d(kernel_size=2),
        )

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(x)
        x3 = self.conv3(x)
        x4 = self.conv4(x)
        # Interleave the four (B, C, h, w) maps into (B, C, 2h, 2w).
        # The original allocated a hard-coded (1, 2, 6, 6) torch.rand
        # buffer and filled it with Python loops over batch index 0
        # only; new_zeros + strided slicing is vectorized, batch-safe,
        # matches the input's dtype/device, and autograd still flows
        # through the slice assignments.
        b, c, h, w = x1.shape
        out = x1.new_zeros(b, c, 2 * h, 2 * w)
        out[:, :, 0::2, 0::2] = x1  # even rows, even cols
        out[:, :, 1::2, 0::2] = x2  # odd rows,  even cols
        out[:, :, 0::2, 1::2] = x3  # even rows, odd cols
        out[:, :, 1::2, 1::2] = x4  # odd rows,  odd cols
        return out
最后前向传递时:
# Short training demo: fit the interleaving CNN to a random target to
# show that the hand-written pixel rearrangement still lets gradients
# flow through all four branches.
y = torch.rand((1, 2, 6, 6))  # dummy regression target
net1 = CNN()
optimizer = torch.optim.Adam(net1.parameters(), lr=0.001)
loss_func = nn.MSELoss()
for _ in range(5):
    out = net1(x)
    print(out)
    # nn.MSELoss takes (input, target); the original passed
    # (target, input) -- harmless for MSE, which is symmetric, but a
    # latent bug if the loss is ever swapped, so use the right order.
    # The .float() casts were dropped: both tensors are already float32.
    loss = loss_func(out, y)
    print(loss)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
tensor([[[[ 0.9016, 0.4131, 1.0303, 1.8486, 0.7399, 1.6501],
[ 0.3555, 0.4024, -0.2194, -0.1400, -0.0212, -0.2051],
[ 1.2377, -0.1892, 0.0386, 1.0323, 1.0347, 0.5377],
[ 1.4075, 1.8427, 0.7063, 0.3648, 0.2003, 0.3392],
[ 0.7057, 0.0325, 0.4709, 0.1748, 1.7497, 0.5532],
[ 1.3813, 2.0583, 1.8566, 1.0775, 0.3454, 0.7321]],
[[ 0.3492, 1.8865, 3.3384, 0.0781, 1.0440, -0.1159],
[ 1.3289, 0.4629, 0.4022, -0.1841, -0.0776, -0.1137],
[ 0.4659, 1.9477, 0.8178, 0.0798, 0.0837, -0.2159],
[ 1.8844, 1.1195, 0.3197, 1.4200, -0.1108, 0.2768],
[ 0.3230, 0.5641, 0.5763, 0.6091, 0.2812, -0.0158],
[ 0.7778, -0.1344, 0.6206, 1.3715, -0.1976, 1.2581]]]],
grad_fn=<CopySlices>)
tensor(0.6613, grad_fn=<MseLossBackward0>)
tensor([[[[ 0.8990, 0.3239, 1.0170, 1.5695, 0.7333, 1.5553],
[ 0.3600, 0.2615, -0.2155, -0.1731, 0.0190, -0.1857],
[ 1.2168, -0.1889, 0.0361, 1.0987, 0.9943, 0.5901],
[ 1.3173, 1.4860, 0.7967, 0.3845, 0.2338, 0.4434],
[ 0.7157, 0.0791, 0.4828, 0.2435, 1.7413, 0.6953],
[ 1.3505, 1.8685, 1.8037, 1.2310, 0.3586, 0.8277]],
[[ 0.3342, 1.8011, 3.2821, 0.0930, 1.0717, -0.1247],
[ 1.2648, 0.4790, 0.4416, -0.1916, -0.0708, -0.0669],
[ 0.4872, 1.8047, 0.8081, 0.1078, 0.0991, -0.2224],
[ 1.8261, 0.8794, 0.3289, 1.1405, -0.1094, 0.4351],
[ 0.3484, 0.6295, 0.5573, 0.6343, 0.2667, -0.0140],
[ 0.8026, -0.1500, 0.6216, 1.2892, -0.1935, 1.4045]]]],
grad_fn=<CopySlices>)
tensor(0.5859, grad_fn=<MseLossBackward0>)
tensor([[[[ 0.8956, 0.2482, 1.0158, 1.3058, 0.7264, 1.4573],
[ 0.3683, 0.1539, -0.2119, -0.1935, 0.0648, -0.1530],
[ 1.1957, -0.1817, 0.0338, 1.1497, 0.9552, 0.6347],
[ 1.2274, 1.1734, 0.8838, 0.3882, 0.2747, 0.5580],
[ 0.7250, 0.1435, 0.4942, 0.3161, 1.7321, 0.8241],
[ 1.3121, 1.6799, 1.7506, 1.3245, 0.3710, 0.8787]],
[[ 0.3427, 1.6913, 3.2248, 0.0997, 1.0978, -0.1392],
[ 1.2202, 0.5519, 0.4833, -0.1836, -0.0628, -0.0131],
[ 0.5076, 1.6752, 0.7976, 0.1411, 0.1150, -0.2269],
[ 1.7656, 0.6791, 0.3375, 1.1367, -0.1074, 0.6079],
[ 0.3733, 0.6904, 0.5378, 0.6658, 0.2523, -0.0047],
[ 0.8281, -0.1553, 0.6216, 1.1712, -0.1891, 1.4157]]]],
grad_fn=<CopySlices>)
tensor(0.5275, grad_fn=<MseLossBackward0>)
tensor([[[[ 0.8932, 0.1947, 1.0207, 1.0796, 0.7178, 1.3599],
[ 0.3780, 0.0801, -0.2085, -0.2067, 0.1172, -0.1066],
[ 1.1723, -0.1656, 0.0321, 1.1593, 0.9213, 0.6406],
[ 1.1385, 0.9151, 0.9679, 0.3737, 0.3220, 0.6726],
[ 0.7334, 0.2223, 0.5050, 0.3709, 1.7178, 0.9314],
[ 1.2681, 1.6294, 1.6939, 1.3654, 0.3844, 0.8869]],
[[ 0.3600, 1.5667, 3.1680, 0.1029, 1.1217, -0.1533],
[ 1.2001, 0.6613, 0.5188, -0.1743, -0.0576, 0.0289],
[ 0.5248, 1.5532, 0.7859, 0.1732, 0.1328, -0.2303],
[ 1.7019, 0.5189, 0.3486, 1.0700, -0.1049, 0.7613],
[ 0.3959, 0.7454, 0.5178, 0.6952, 0.2388, 0.0073],
[ 0.8511, -0.1518, 0.6184, 1.0351, -0.1859, 1.3669]]]],
grad_fn=<CopySlices>)
tensor(0.4780, grad_fn=<MseLossBackward0>)
tensor([[[[ 0.8897, 0.1617, 1.0232, 0.8937, 0.7087, 1.2684],
[ 0.3890, 0.0247, -0.2053, -0.2161, 0.1760, -0.0453],
[ 1.1492, -0.1393, 0.0309, 1.1370, 0.8897, 0.6435],
[ 1.0512, 0.6983, 1.0468, 0.3425, 0.3738, 0.7758],
[ 0.7416, 0.3163, 0.5155, 0.4047, 1.7025, 1.0142],
[ 1.2181, 1.5586, 1.6324, 1.3578, 0.3984, 0.8748]],
[[ 0.3771, 1.4487, 3.1112, 0.1041, 1.1430, -0.1658],
[ 1.1689, 0.7852, 0.5506, -0.1648, -0.0538, 0.0702],
[ 0.5406, 1.4380, 0.7727, 0.2035, 0.1515, -0.2326],
[ 1.6384, 0.3959, 0.3605, 0.9520, -0.1018, 0.8885],
[ 0.4178, 0.7936, 0.4966, 0.7211, 0.2248, 0.0223],
[ 0.8742, -0.1433, 0.6147, 0.9079, -0.1826, 1.3010]]]],
grad_fn=<CopySlices>)
tensor(0.4340, grad_fn=<MseLossBackward0>)
可以看出 loss 随迭代逐步下降(0.66 → 0.43),说明自定义的像素重排操作不会影响梯度的反向传播,训练不受影响。