libtorch学习第六

构建卷积网络

#include <cstdint>
#include <iostream>
#include <torch/torch.h>

using std::cout; using std::endl;

class LinearBnReluImpl : public torch::nn::Module
{
private:
	torch::nn::Linear ln{ nullptr };
	torch::nn::BatchNorm1d bn{ nullptr };

public:
	LinearBnReluImpl(int input_features, int out_features);
	torch::Tensor forward(torch::Tensor x);
};
TORCH_MODULE(LinearBnRelu);

/// Convenience factory for torch::nn::Conv2dOptions.
///
/// @param in_planes    number of input channels
/// @param out_planes   number of output channels
/// @param kernel_size  kernel size passed to the Conv2dOptions constructor
/// @param stride       convolution stride (default 1)
/// @param padding      padding (default 0)
/// @param with_bias    whether the convolution has a bias term (default false)
/// @return the fully populated options object
inline torch::nn::Conv2dOptions conv_options(
	int64_t in_planes, int64_t out_planes, int64_t kernel_size,
	int64_t stride = 1, int64_t padding = 0, bool with_bias = false
)
{
	// Each setter returns the options object, so the calls can be chained.
	return torch::nn::Conv2dOptions(in_planes, out_planes, kernel_size)
		.stride(stride)
		.padding(padding)
		.bias(with_bias);
}

class ConvReluBnImpl : public torch::nn::Module
{
private:
	torch::nn::Conv2d conv{ nullptr };
	torch::nn::BatchNorm2d bn{ nullptr };
	
public:
	ConvReluBnImpl(int input_channel, int output_channel, int kernel_size, int stride, int padding=1);
	torch::Tensor forward(torch::Tensor x);
};
TORCH_MODULE(ConvReluBn);


class MLP : public torch::nn::Module
{
private:
	int mid_features[3] = { 32, 64, 128 };
	LinearBnRelu ln1{ nullptr };
	LinearBnRelu ln2{ nullptr };
	LinearBnRelu ln3{ nullptr };
	torch::nn::Linear out_ln{ nullptr };

public:
	MLP(int in_features, int out_features);
	torch::Tensor forward(torch::Tensor x);
};

class plainCNN : public torch::nn::Module
{
private:
	int mid_channels[3]{ 32,64,128 };
	ConvReluBn conv1{ nullptr };
	ConvReluBn down1{ nullptr };
	ConvReluBn conv2{ nullptr };
	ConvReluBn down2{ nullptr };
	ConvReluBn conv3{ nullptr };
	ConvReluBn down3{ nullptr };
	torch::nn::Conv2d out_conv{ nullptr };

public:
	plainCNN(int in_channels, int out_channels);
	torch::Tensor forward(torch::Tensor x);
};


int main()
{
	plainCNN c(3, 2);

	auto x = torch::rand({ 1,3,224,224 }, torch::kFloat);
	//cout << x.sizes() << endl;
	
	auto a = c.forward(x);
	cout <<"[in Main]: "<< a.sizes() << endl;

	return 0;
}

/// Creates the linear layer (input_features -> out_features) and a BatchNorm1d
/// over out_features, registering them under the names "ln" and "bn".
LinearBnReluImpl::LinearBnReluImpl(int input_features, int out_features)
{
	const torch::nn::LinearOptions linear_opts(input_features, out_features);
	ln = register_module("ln", torch::nn::Linear(linear_opts));

	bn = register_module("bn", torch::nn::BatchNorm1d(out_features));
}

/// Applies Linear, then ReLU, then BatchNorm1d to x and returns the result.
torch::Tensor LinearBnReluImpl::forward(torch::Tensor x)
{
	const torch::Tensor activated = torch::relu(ln->forward(x));
	// Batch norm runs after the activation in this block.
	return bn->forward(activated);
}

/// Creates the convolution (options built by conv_options, which leaves the
/// bias disabled by default) and a BatchNorm2d over output_channel, registering
/// them under the names "conv" and "bn".
ConvReluBnImpl::ConvReluBnImpl(int input_channel, int output_channel, int kernel_size, int stride, int padding)
{
	const torch::nn::Conv2dOptions opts =
		conv_options(input_channel, output_channel, kernel_size, stride, padding);
	conv = register_module("conv", torch::nn::Conv2d(opts));

	bn = register_module("bn", torch::nn::BatchNorm2d(output_channel));
}

/// Applies Conv2d, then ReLU, then BatchNorm2d to x and returns the result.
torch::Tensor ConvReluBnImpl::forward(torch::Tensor x)
{
	const torch::Tensor activated = torch::relu(conv->forward(x));
	// Batch norm runs after the activation in this block.
	return bn->forward(activated);
}

/// Builds and registers the layers:
/// in_features -> 32 -> 64 -> 128 (LinearBnRelu blocks) -> out_features (Linear).
MLP::MLP(int in_features, int out_features)
{
	// Each sub-module is created and registered in one step; creation and
	// registration both happen in the order ln1, ln2, ln3, out_ln.
	ln1 = register_module("ln1", LinearBnRelu(in_features, mid_features[0]));
	ln2 = register_module("ln2", LinearBnRelu(mid_features[0], mid_features[1]));
	ln3 = register_module("ln3", LinearBnRelu(mid_features[1], mid_features[2]));
	out_ln = register_module("out_ln", torch::nn::Linear(mid_features[2], out_features));
}

/// Feeds x through ln1, ln2, ln3 and the output linear layer, in that order.
torch::Tensor MLP::forward(torch::Tensor x)
{
	// Innermost call runs first: ln1 -> ln2 -> ln3 -> out_ln.
	return out_ln->forward(
		ln3->forward(
			ln2->forward(
				ln1->forward(x))));
}

/// Builds and registers the stages. Every ConvReluBn block uses a 3x3 kernel
/// and the block's default padding; the "conv" blocks use stride 1 and change
/// the channel count, the "down" blocks use stride 2 and keep it. The final
/// Conv2d is a 3x3 convolution with conv_options' defaults.
plainCNN::plainCNN(int in_channels, int out_channels)
{
	// Each sub-module is created and registered in one step; creation and
	// registration both follow the order the stages run in forward().
	conv1 = register_module("conv1", ConvReluBn(in_channels, mid_channels[0], 3, 1));
	down1 = register_module("down1", ConvReluBn(mid_channels[0], mid_channels[0], 3, 2));

	conv2 = register_module("conv2", ConvReluBn(mid_channels[0], mid_channels[1], 3, 1));
	down2 = register_module("down2", ConvReluBn(mid_channels[1], mid_channels[1], 3, 2));

	conv3 = register_module("conv3", ConvReluBn(mid_channels[1], mid_channels[2], 3, 1));
	down3 = register_module("down3", ConvReluBn(mid_channels[2], mid_channels[2], 3, 2));

	out_conv = register_module(
		"out_conv",
		torch::nn::Conv2d(conv_options(mid_channels[2], out_channels, 3)));
}

/// Runs x through conv1, down1, conv2, down2, conv3, down3 and out_conv,
/// printing the tensor shape to stdout after every stage, and returns the
/// final tensor.
torch::Tensor plainCNN::forward(torch::Tensor x)
{
	// Applies one stage to x in place and echoes the resulting shape.
	// Generic so it accepts both ConvReluBn and torch::nn::Conv2d holders.
	const auto run_stage = [&x](auto& stage) {
		x = stage->forward(x);
		cout << x.sizes() << endl;
	};

	run_stage(conv1);
	run_stage(down1);
	run_stage(conv2);
	run_stage(down2);
	run_stage(conv3);
	run_stage(down3);
	run_stage(out_conv);

	return x;
}


结果

libtorch学习第六_第1张图片

你可能感兴趣的:(pytorch,学习,pytorch,人工智能)