tflite模型下载

在学习 TensorFlow Lite 的示例应用时,常常会看到应用需要用到一些 .tflite 模型,而这些模型的下载过程写在 download.gradle 文件中。

其下载脚本如下:

// Fetches the PoseNet (MobileNet v1, 257x257) model into the app's assets as posenet.tflite.
task downloadPosenetModel(type: DownloadUrlTask) {
    def url = "https://storage.googleapis.com/download.tensorflow.org/models/tflite/posenet_mobilenet_v1_100_257x257_multi_kpt_stripped.tflite"
    sourceUrl = url
    target = file("src/main/assets/posenet.tflite")
    // Announce the URL right before the download action executes.
    doFirst { println "Downloading ${url}" }
}

// Fetches the MoveNet single-pose Lightning (float16) model into the app's assets.
task downloadMovenetLightningModel(type: DownloadUrlTask) {
    def url = "https://tfhub.dev/google/lite-model/movenet/singlepose/lightning/tflite/float16/4?lite-format=tflite"
    sourceUrl = url
    target = file("src/main/assets/movenet_lightning.tflite")
    // Announce the URL right before the download action executes.
    doFirst { println "Downloading ${url}" }
}

// Fetches the MoveNet single-pose Thunder (float16) model into the app's assets.
task downloadMovenetThunderModel(type: DownloadUrlTask) {
    def url = "https://tfhub.dev/google/lite-model/movenet/singlepose/thunder/tflite/float16/4?lite-format=tflite"
    sourceUrl = url
    target = file("src/main/assets/movenet_thunder.tflite")
    // Announce the URL right before the download action executes.
    doFirst { println "Downloading ${url}" }
}

// Fetches the MoveNet multi-pose Lightning (float16) model into the app's assets.
task downloadMovenetMultiPoseModel(type: DownloadUrlTask) {
    // Named after the multipose model; the original reused the Thunder variable name by copy-paste.
    def modelMovenetMultiPoseDownloadUrl = "https://tfhub.dev/google/lite-model/movenet/multipose/lightning/tflite/float16/1?lite-format=tflite"
    doFirst {
        println "Downloading ${modelMovenetMultiPoseDownloadUrl}"
    }
    sourceUrl = "${modelMovenetMultiPoseDownloadUrl}"
    target = file("src/main/assets/movenet_multipose_fp16.tflite")
}

// Fetches the yoga pose classifier model into the app's assets as classifier.tflite.
task downloadPoseClassifierModel(type: DownloadUrlTask) {
    def url = "https://storage.googleapis.com/download.tensorflow.org/models/tflite/pose_classifier/yoga_classifier.tflite"
    sourceUrl = url
    target = file("src/main/assets/classifier.tflite")
    // Announce the URL right before the download action executes.
    doFirst { println "Downloading ${url}" }
}

// Aggregate task with no action of its own: running it triggers every model download above.
task downloadModel {
    dependsOn(
            downloadPosenetModel,
            downloadMovenetLightningModel,
            downloadMovenetThunderModel,
            downloadPoseClassifierModel,
            downloadMovenetMultiPoseModel
    )
}

/**
 * Downloads a single file from {@code sourceUrl} to {@code target} using Ant's {@code get} task.
 *
 * The URL is declared as a task input and the file as a task output, so Gradle's up-to-date
 * check can skip the task when neither has changed since the last successful run.
 */
class DownloadUrlTask extends DefaultTask {
    /** Remote URL of the file to fetch. */
    @Input
    String sourceUrl

    /** Local destination file. */
    @OutputFile
    File target

    @TaskAction
    void download() {
        // skipexisting: leave an already-present target untouched instead of contacting the
        // server. Gradle has no task history on a fresh checkout, so without this a model that
        // was placed by hand (because the host is unreachable) would still be re-downloaded
        // and the build would fail.
        ant.get(src: sourceUrl, dest: target, skipexisting: true)
    }
}

// Hook the model downloads into the Android build so the assets exist before compilation.
preBuild.dependsOn(downloadModel)

另外,还有一些示例应用使用 Download 任务类型(来自 de.undercouch.download 插件,即 gradle-download-task)来下载 tflite 模型,写法如下:

// Fetches the MNIST digit classifier into the asset directory; an existing file is not overwritten.
// NOTE(review): assumes project.ext.ASSET_DIR is defined elsewhere in the build script.
task downloadModelFile(type: Download) {
    overwrite(false)
    src('https://storage.googleapis.com/download.tensorflow.org/models/tflite/digit_classifier/mnist.tflite')
    dest(project.ext.ASSET_DIR + '/mnist.tflite')
}


// As the debug/release assemble tasks are created, make each depend on the MNIST download.
tasks.whenTaskAdded { addedTask ->
    if (addedTask.name in ['assembleDebug', 'assembleRelease']) {
        addedTask.dependsOn 'downloadModelFile'
    }
}

在国内,往往无法直接从 tfhub.dev 和 googleapis 网站下载这些模型,因此需要先通过外网下载,再把文件放到 app/src/main/assets 目录下,并按脚本中 target/dest 指定的文件名重命名(例如 posenet.tflite、movenet_lightning.tflite、mnist.tflite)。

通常建议把第一种写法改为上面的第二种模式:第二种模式设置了 overwrite false,已存在的模型文件不会被重新下载,手动放置的文件因此可以直接使用。

如果仍然不行,直接清空第一种模式中 download.gradle 的内容即可(前提是模型文件已经手动放好)。

你可能感兴趣的:(Android,tensorflow,深度学习,人工智能)