Building a Convolutional Network
#include <torch/torch.h>
#include <torch/script.h>
#include <iostream>
using std::cout; using std::endl;
// A fully connected block: Linear -> ReLU -> BatchNorm1d
// (despite the name, the forward pass below applies ReLU before BN).
class LinearBnReluImpl : public torch::nn::Module
{
private:
    torch::nn::Linear ln{ nullptr };
    torch::nn::BatchNorm1d bn{ nullptr };
public:
    LinearBnReluImpl(int input_features, int out_features);
    torch::Tensor forward(torch::Tensor x);
};
TORCH_MODULE(LinearBnRelu);
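TORCH_MODULE(LinearBnRelu) generates a module holder named LinearBnRelu that wraps LinearBnReluImpl in a std::shared_ptr and forwards constructor arguments to it, which is why the blocks below are called through ->. A minimal standalone use (the 10-in/32-out sizes and the batch of 4 are arbitrary choices for illustration):

LinearBnRelu block(10, 32);                       // holder constructs the Impl internally
auto y = block->forward(torch::rand({ 4, 10 })); // y.sizes() == [4, 32]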
// Helper that bundles the common Conv2d settings into a single call.
inline torch::nn::Conv2dOptions conv_options(
    int64_t in_planes, int64_t out_planes, int64_t kernel_size,
    int64_t stride = 1, int64_t padding = 0, bool with_bias = false)
{
    torch::nn::Conv2dOptions options = torch::nn::Conv2dOptions(in_planes, out_planes, kernel_size);
    options.stride(stride);
    options.padding(padding);
    options.bias(with_bias);
    return options;
}
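For example, the call below (with arbitrary sizes picked for illustration) is equivalent to torch::nn::Conv2dOptions(3, 16, 3).stride(2).padding(1).bias(false):

torch::nn::Conv2d conv(conv_options(3, 16, /*kernel_size=*/3, /*stride=*/2, /*padding=*/1));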
// A convolutional block: Conv2d -> ReLU -> BatchNorm2d.
class ConvReluBnImpl : public torch::nn::Module
{
private:
    torch::nn::Conv2d conv{ nullptr };
    torch::nn::BatchNorm2d bn{ nullptr };
public:
    ConvReluBnImpl(int input_channel, int output_channel, int kernel_size, int stride, int padding = 1);
    torch::Tensor forward(torch::Tensor x);
};
TORCH_MODULE(ConvReluBn);
// Three LinearBnRelu blocks of width 32/64/128, followed by a plain linear head.
class MLP : public torch::nn::Module
{
private:
    int mid_features[3] = { 32, 64, 128 };
    LinearBnRelu ln1{ nullptr };
    LinearBnRelu ln2{ nullptr };
    LinearBnRelu ln3{ nullptr };
    torch::nn::Linear out_ln{ nullptr };
public:
    MLP(int in_features, int out_features);
    torch::Tensor forward(torch::Tensor x);
};
// A plain CNN: three conv stages, each followed by a stride-2 downsampling conv.
class plainCNN : public torch::nn::Module
{
private:
    int mid_channels[3]{ 32, 64, 128 };
    ConvReluBn conv1{ nullptr };
    ConvReluBn down1{ nullptr };
    ConvReluBn conv2{ nullptr };
    ConvReluBn down2{ nullptr };
    ConvReluBn conv3{ nullptr };
    ConvReluBn down3{ nullptr };
    torch::nn::Conv2d out_conv{ nullptr };
public:
    plainCNN(int in_channels, int out_channels);
    torch::Tensor forward(torch::Tensor x);
};
int main()
{
    plainCNN c(3, 2);
    auto x = torch::rand({ 1, 3, 224, 224 }, torch::kFloat);
    auto a = c.forward(x);
    cout << "[in Main]: " << a.sizes() << endl;
    return 0;
}
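main() here only checks shapes, so the forward pass runs in training mode and the BatchNorm layers normalize with batch statistics. For actual inference one would typically add, before the forward call, something like this sketch (not part of the original):

torch::NoGradGuard no_grad;  // skip autograd bookkeeping during inference
c.eval();                    // BatchNorm switches to its running statistics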
LinearBnReluImpl::LinearBnReluImpl(int input_features, int out_features)
{
    ln = register_module("ln", torch::nn::Linear(torch::nn::LinearOptions(input_features, out_features)));
    bn = register_module("bn", torch::nn::BatchNorm1d(out_features));
}
torch::Tensor LinearBnReluImpl::forward(torch::Tensor x)
{
    x = torch::relu(ln->forward(x));  // linear layer, then ReLU
    x = bn(x);                        // batch norm applied after the activation
    return x;
}
ConvReluBnImpl::ConvReluBnImpl(int input_channel, int output_channel, int kernel_size, int stride, int padding)
{
    conv = register_module("conv", torch::nn::Conv2d(conv_options(input_channel, output_channel, kernel_size, stride, padding)));
    bn = register_module("bn", torch::nn::BatchNorm2d(output_channel));
}
torch::Tensor ConvReluBnImpl::forward(torch::Tensor x)
{
    x = torch::relu(conv->forward(x));  // convolution, then ReLU
    x = bn(x);                          // batch norm applied after the activation
    return x;
}
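As a quick check of the block, a stride-2 ConvReluBn halves the spatial resolution thanks to the default padding of 1; the 32-channel, 64x64 input below is an arbitrary example:

ConvReluBn down(32, 32, /*kernel_size=*/3, /*stride=*/2);
auto y = down->forward(torch::rand({ 1, 32, 64, 64 }));  // y.sizes() == [1, 32, 32, 32]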
MLP::MLP(int in_features, int out_features)
{
    // register_module returns the holder, so construction and registration can be combined.
    ln1 = register_module("ln1", LinearBnRelu(in_features, mid_features[0]));
    ln2 = register_module("ln2", LinearBnRelu(mid_features[0], mid_features[1]));
    ln3 = register_module("ln3", LinearBnRelu(mid_features[1], mid_features[2]));
    out_ln = register_module("out_ln", torch::nn::Linear(mid_features[2], out_features));
}
torch::Tensor MLP::forward(torch::Tensor x)
{
    x = ln1->forward(x);
    x = ln2->forward(x);
    x = ln3->forward(x);
    x = out_ln->forward(x);
    return x;
}
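MLP is defined but never exercised in main; a minimal usage sketch follows (the feature sizes and the batch of 4 are arbitrary, and the batch must be larger than 1 because BatchNorm1d cannot compute batch statistics from a single sample in training mode):

MLP mlp(10, 3);                                  // 10 input features, 3 outputs
auto out = mlp.forward(torch::rand({ 4, 10 })); // out.sizes() == [4, 3]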
plainCNN::plainCNN(int in_channels, int out_channels)
{
    // Each convN keeps the spatial size (stride 1, padding 1);
    // each downN halves it (stride 2, padding 1).
    conv1 = register_module("conv1", ConvReluBn(in_channels, mid_channels[0], 3, 1));
    down1 = register_module("down1", ConvReluBn(mid_channels[0], mid_channels[0], 3, 2));
    conv2 = register_module("conv2", ConvReluBn(mid_channels[0], mid_channels[1], 3, 1));
    down2 = register_module("down2", ConvReluBn(mid_channels[1], mid_channels[1], 3, 2));
    conv3 = register_module("conv3", ConvReluBn(mid_channels[1], mid_channels[2], 3, 1));
    down3 = register_module("down3", ConvReluBn(mid_channels[2], mid_channels[2], 3, 2));
    // Final 3x3 conv with no padding, mapping 128 channels to out_channels.
    out_conv = register_module("out_conv", torch::nn::Conv2d(conv_options(mid_channels[2], out_channels, 3)));
}
torch::Tensor plainCNN::forward(torch::Tensor x)
{
    // Each print traces how the feature map shrinks through the network.
    x = conv1->forward(x);
    cout << x.sizes() << endl;
    x = down1->forward(x);
    cout << x.sizes() << endl;
    x = conv2->forward(x);
    cout << x.sizes() << endl;
    x = down2->forward(x);
    cout << x.sizes() << endl;
    x = conv3->forward(x);
    cout << x.sizes() << endl;
    x = down3->forward(x);
    cout << x.sizes() << endl;
    x = out_conv->forward(x);
    cout << x.sizes() << endl;
    return x;
}
Result
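Each layer's output size follows the standard convolution arithmetic out = floor((in + 2 * padding - kernel) / stride) + 1, so with a [1, 3, 224, 224] input the program should print:

[1, 32, 224, 224]
[1, 32, 112, 112]
[1, 64, 112, 112]
[1, 64, 56, 56]
[1, 128, 56, 56]
[1, 128, 28, 28]
[1, 2, 26, 26]
[in Main]: [1, 2, 26, 26]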