PyTorch C++ (LibTorch) Conv2d 示例

代码

#include 
#include 

#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 
#include 

#include 
#include 
#include 
#include 
#include 
#include 
#include 

using namespace std;
using namespace at;
using namespace torch::nn;
using namespace torch::optim;

// Minimal LibTorch example: push a random input batch through a single 2-D
// convolution layer and print the resulting tensor shape.
// Needs <torch/torch.h> and <iostream>; the #include targets at the top of this
// snippet appear to have been lost during extraction — TODO restore them.
int main(int argc, const char* argv[])
{
    // Input batch laid out as N * C * H * W = 2 * 3 * 5 * 5.
    // torch::requires_grad() asks autograd to track operations on x.
    auto x = torch::randn({2, 3, 5, 5}, torch::requires_grad());

    // 3 input channels -> 2 output channels, 3x3 kernel, stride 2, padding 1.
    // Fully qualified so this does not depend on `using namespace torch::nn`.
    torch::nn::Conv2d model(
        torch::nn::Conv2dOptions(3, 2, {3, 3}).stride(2).padding(1));
    auto y = model->forward(x);

    // Spatial size per dim: floor((5 + 2*1 - 3) / 2) + 1 = 3 -> prints [2, 2, 3, 3].
    std::cout << y.sizes() << '\n';

    std::cout << "ok\n";
    // Exit status 0 signals success; a non-zero value (the previous `return 1`)
    // makes the shell treat this run as a failure, e.g. in `make && ./bin/demo`.
    return 0;
}

编译并运行

make
./bin/demo

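结果（输出形状为 N × C_out × H_out × W_out；其中 $H_{out} = W_{out} = \lfloor (5 + 2\times1 - 3)/2 \rfloor + 1 = 3$，因此打印 [2, 2, 3, 3]）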

[2, 2, 3, 3]
ok

你可能感兴趣的:(pytorch)