PyTorch Android deployment demo: pitfalls when using my own custom-trained model

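A note on a small pitfall I hit when plugging my own model (a VGG network with the number of output classes changed to 40) into the GitHub demo project. The app crashed as soon as it tried to load the model: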

2021-01-26 19:02:42.191 19212-19370/org.pytorch.demo E/AndroidRuntime: FATAL EXCEPTION: ModuleActivity
    Process: org.pytorch.demo, PID: 19212
    com.facebook.jni.CppException: 
    
    aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> (Tensor):
    Expected at most 12 arguments but found 13 positional arguments.
    :
    /home/xutengfei/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py(419): _conv_forward
    /home/xutengfei/.local/lib/python3.8/site-packages/torch/nn/modules/conv.py(423): forward
    /home/xutengfei/.local/lib/python3.8/site-packages/torch/nn/modules/module.py(709): _slow_forward
    /home/xutengfei/.local/lib/python3.8/site-packages/torch/nn/modules/module.py(725): _call_impl
    /home/xutengfei/.local/lib/python3.8/site-packages/torch/nn/modules/container.py(117): forward
    /home/xutengfei/.local/lib/python3.8/site-packages/torch/nn/modules/module.py(709): _slow_forward
    /home/xutengfei/.local/lib/python3.8/site-packages/torch/nn/modules/module.py(725): _call_impl
    /home/xutengfei/garbage_classify/mycode/vgg.py(42): forward
    /home/xutengfei/.local/lib/python3.8/site-packages/torch/nn/modules/module.py(709): _slow_forward
    /home/xutengfei/.local/lib/python3.8/site-packages/torch/nn/modules/module.py(725): _call_impl
    /home/xutengfei/garbage_classify/mycode/myVGG.py(25): forward
    /home/xutengfei/.local/lib/python3.8/site-packages/torch/nn/modules/module.py(709): _slow_forward
    /home/xutengfei/.local/lib/python3.8/site-packages/torch/nn/modules/module.py(725): _call_impl
    /home/xutengfei/.local/lib/python3.8/site-packages/torch/jit/_trace.py(934): trace_module
    /home/xutengfei/.local/lib/python3.8/site-packages/torch/jit/_trace.py(733): trace
    /home/xutengfei/garbage_classify/mycode/deployment_script.py(32): <module>
    Serialized   File "code/__torch__/torch/nn/modules/conv.py", line 10
        input: Tensor) -> Tensor:
        _0 = self.bias
        input0 = torch._convolution(input, self.weight, _0, [1, 1], [1, 1], [1, 1], False, [0, 0], 1, False, False, True, True)
                 ~~~~~~~~~~~~~~~~~~ <--- HERE
        return input0
    
        at org.pytorch.NativePeer.initHybrid(Native Method)
        at org.pytorch.NativePeer.<init>(NativePeer.java:24)
        at org.pytorch.Module.load(Module.java:23)
        at org.pytorch.demo.vision.ImageClassificationActivity.analyzeImage(ImageClassificationActivity.java:166)
        at org.pytorch.demo.vision.ImageClassificationActivity.analyzeImage(ImageClassificationActivity.java:31)
        at org.pytorch.demo.vision.AbstractCameraXActivity.lambda$setupCameraX$2$AbstractCameraXActivity(AbstractCameraXActivity.java:90)
        at org.pytorch.demo.vision.-$$Lambda$AbstractCameraXActivity$t0OjLr-l_M0-_0_dUqVE4yqEYnE.analyze(Unknown Source:2)
        at androidx.camera.core.ImageAnalysisAbstractAnalyzer.analyzeImage(ImageAnalysisAbstractAnalyzer.java:57)
        at androidx.camera.core.ImageAnalysisNonBlockingAnalyzer$1.run(ImageAnalysisNonBlockingAnalyzer.java:135)
        at android.os.Handler.handleCallback(Handler.java:900)
        at android.os.Handler.dispatchMessage(Handler.java:103)
        at android.os.Looper.loop(Looper.java:219)
        at android.os.HandlerThread.run(HandlerThread.java:67)

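The key line in the error is: Expected at most 12 arguments but found 13 positional arguments.
Comparing the schema with the serialized call argument by argument confirms that there is indeed one extra True at the end:

aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> (Tensor)

input          <- input
weight         <- self.weight
bias           <- _0
stride         <- [1, 1]
padding        <- [1, 1]
dilation       <- [1, 1]
transposed     <- False
output_padding <- [0, 0]
groups         <- 1
benchmark      <- False
deterministic  <- False
cudnn_enabled  <- True
(no parameter) <- True    this 13th argument is the problem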
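Counting 13 values by eye is error-prone, so here is a small stdlib-only Python sketch that does the same comparison mechanically. The helper names `inner_args` and `split_top_level` are my own (not part of PyTorch); the two strings are copied from the crash log. It splits both the schema and the serialized call at top-level commas (ignoring commas inside `[1, 1]`) and lines them up, which makes the surplus trailing `True` obvious.

```python
# Stdlib-only check: line up the aten::_convolution schema with the serialized call
# from the crash log and count the positional arguments on each side.


def inner_args(text):
    """Return the text between the first '(' and its matching ')'."""
    start = text.index("(")
    depth = 0
    for i in range(start, len(text)):
        if text[i] == "(":
            depth += 1
        elif text[i] == ")":
            depth -= 1
            if depth == 0:
                return text[start + 1:i]
    raise ValueError("unbalanced parentheses")


def split_top_level(args):
    """Split on commas that are not nested inside [] or ()."""
    parts, current, depth = [], [], 0
    for ch in args:
        if ch in "[(":
            depth += 1
        elif ch in "])":
            depth -= 1
        if ch == "," and depth == 0:
            parts.append("".join(current).strip())
            current = []
        else:
            current.append(ch)
    tail = "".join(current).strip()
    if tail:
        parts.append(tail)
    return parts


# Both strings are copied verbatim from the error message.
SCHEMA = ("aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, "
          "int[] padding, int[] dilation, bool transposed, int[] output_padding, "
          "int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> (Tensor)")
CALL = ("torch._convolution(input, self.weight, _0, [1, 1], [1, 1], [1, 1], "
        "False, [0, 0], 1, False, False, True, True)")

params = split_top_level(inner_args(SCHEMA))
args = split_top_level(inner_args(CALL))
print(len(params), len(args))  # 12 13
for param, value in zip(params, args):
    print(f"{param:<24} <- {value}")
print("unmatched:", args[len(params):])  # unmatched: ['True']
```

The bracket-depth counter is the only subtle part: without it, `[1, 1]` would be split into two bogus arguments and the count would be wrong.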

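Searching for the error message pointed to a version mismatch: the model had been traced and saved with a newer desktop PyTorch than the pytorch_android runtime the demo app was built against. The newer version serializes _convolution with one extra boolean argument (most likely the allow_tf32 flag introduced in PyTorch 1.7), which the older runtime's 12-argument schema rejects.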
Sure enough, a GitHub issue had the answer: bump the pytorch_android dependencies in the app's build.gradle to the latest version, and the model loads fine:

dependencies {
    implementation 'org.pytorch:pytorch_android:1.7.0'
    implementation 'org.pytorch:pytorch_android_torchvision:1.7.0'
}
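To avoid hitting this again after the next desktop PyTorch upgrade, a quick pre-build check can compare the two versions. The sketch below is stdlib-only Python; the function names are hypothetical, and the rule "pytorch_android should be at least as new as the PyTorch that exported the model" is my own rule of thumb, not an official compatibility guarantee. The old version string in the usage example is made up for illustration.

```python
import re


def parse_version(version):
    """'1.7.0' or '1.7.0+cu110' -> (1, 7, 0); local build suffixes after '+' are ignored."""
    core = version.split("+")[0]
    return tuple(int(n) for n in re.findall(r"\d+", core)[:3])


def android_versions(gradle_text):
    """Map each org.pytorch artifact declared in a build.gradle snippet to its version."""
    pattern = r"['\"]org\.pytorch:([\w\-]+):([\w.\-+]+)['\"]"
    return dict(re.findall(pattern, gradle_text))


def runtime_is_new_enough(export_torch_version, gradle_text):
    """Assumed rule of thumb: pytorch_android should be at least as new as the
    desktop PyTorch that traced and saved the model."""
    versions = android_versions(gradle_text)
    if "pytorch_android" not in versions:
        raise ValueError("no org.pytorch:pytorch_android dependency found")
    return parse_version(versions["pytorch_android"]) >= parse_version(export_torch_version)


GRADLE = """
dependencies {
    implementation 'org.pytorch:pytorch_android:1.7.0'
    implementation 'org.pytorch:pytorch_android_torchvision:1.7.0'
}
"""

# On the export machine this string would come from torch.__version__.
print(runtime_is_new_enough("1.7.0", GRADLE))  # True
# Hypothetical older dependency, for illustration only.
print(runtime_is_new_enough("1.7.0", GRADLE.replace("1.7.0", "1.4.0")))  # False
```

Comparing version tuples instead of raw strings matters here: as strings, "1.10.0" would sort before "1.7.0".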
