yolov5 hard swish实现

github上看到的

原作者是

auto hsig = network->addActivation(*bn1->getOutput(0), ActivationType::kHARD_SIGMOID);
sig->setAlpha(1.0 / 6.0);
hsig->setBeta(0.5);
auto ew = network->addElementWise(*bn1->getOutput(0), *hsig->getOutput(0), ElementWiseOperation::kPROD);
return ew;

有网友改的

// 大佬,你好,目前hard swish是以插件形式实现的,感觉速度慢太多,同时插件只支持fp32模式。
// 我自己基于tensorrt api实现了hard swish,速度比插件要快不少。
//hard swish
Weights emptywts{ DataType::kFLOAT, nullptr, 0 };
float scval = reinterpret_cast(malloc(sizeof(float)));
scval[0] = 1.0 / 6.0;
Weights scale{ DataType::kFLOAT, scval, 1 };
float shval = reinterpret_cast(malloc(sizeof(float)));
shval[0] = 0.5;
Weights shift{ DataType::kFLOAT, shval, 1 };
float pval = reinterpret_cast(malloc(sizeof(float)));
pval[0] = 1.0;
Weights power{ DataType::kFLOAT, pval, 1 };
auto clip = network->addActivation(*bn1->getOutput(0), ActivationType::kCLIP);
clip->setAlpha(-3.0);
clip->setBeta(3.0);
auto sc = network->addScale(*clip->getOutput(0), ScaleMode::kUNIFORM, shift, scale, emptywts);
auto hs = network->addElementWise(*sc->getOutput(0), *bn1->getOutput(0), ElementWiseOperation::kPROD);
return hs;
// 除此之外,addBatchNorm2d中的eps应该设置为1e-5,而不是1e-3,pytorch里面默认使用的是1e-5。

https://github.com/wang-xinyu/tensorrtx/issues/182

你可能感兴趣的:(TensorRT,深度学习)