深度学习pytorch代码:非线性激活Relu()

 input为ReLu()中的一个参数,默认为Faluse,保留输入数据

import torch
from torch.nn import ReLU
from torch import nn

input =torch.tensor([
    [1, -0.5], [-1, 3]    # 1为batchsize
])
output = torch.reshape(input, (-1, 1, 2, 2))
print(output.shape)


class LR(nn.Module):
    def __init__(self):
        super(LR, self).__init__()
        self.relu1 = ReLU()

    def forward(self, input):
        output = self.relu1(input)
        return output

lrp = LR()
output = lrp(input)
print(output)
# tensor([[1., 0.],
#         [0., 3.]])     当x<0,x=0



深度学习pytorch代码:非线性激活Relu()_第1张图片

 

你可能感兴趣的:(pytorch学习,人工智能,pytorch,机器学习,python,图像处理)