Java调用Pytorch实现以图搜图(附源码)

Java调用Pytorch实现以图搜图

涉及技术栈:
1、ElasticSearch环境;
2、Python运行环境(如果事先没有PyTorch模型,可以用Python脚本创建模型);

1、运行效果

Java调用Pytorch实现以图搜图(附源码)_第1张图片

2、创建模型(有则可以跳过)

vi script.py

import torch
import torch.nn as nn
import torchvision.models as models
 
class ImageFeatureExtractor(nn.Module):
    """ResNet-50 backbone whose final ``fc`` layer is replaced by a 2048 -> 1024 linear layer."""

    def __init__(self):
        super().__init__()
        backbone = models.resnet50(pretrained=True)
        # The final output is a 1024-dimensional vector; the Elasticsearch
        # mapping further down must therefore set dims to 1024.
        backbone.fc = nn.Linear(2048, 1024)
        self.resnet = backbone

    def forward(self, x):
        """Run ``x`` through the modified backbone and return its output."""
        return self.resnet(x)
 
if __name__ == '__main__':
    # Build the extractor and switch it to inference mode before tracing.
    model = ImageFeatureExtractor()
    model.eval()
    # Any tensor with the model's input shape works as the example input for tracing.
    example_input = torch.rand([1, 3, 224, 224])
    # Save as a traced TorchScript module; the Java code below loads "model.pt".
    traced = torch.jit.trace(model, example_input)
    traced.save("model.pt")

2、java项目pom.xml


		
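	<dependencies>
		<dependency>
			<groupId>org.springframework.boot</groupId>
			<artifactId>spring-boot-starter-web</artifactId>
		</dependency>
		<dependency>
			<groupId>org.projectlombok</groupId>
			<artifactId>lombok</artifactId>
			<scope>provided</scope>
		</dependency>
		<dependency>
			<groupId>ai.djl.pytorch</groupId>
			<artifactId>pytorch-engine</artifactId>
			<version>0.19.0</version>
		</dependency>
		<dependency>
			<groupId>ai.djl.pytorch</groupId>
			<artifactId>pytorch-native-cpu</artifactId>
			<version>1.10.0</version>
			<scope>runtime</scope>
		</dependency>
		<dependency>
			<groupId>ai.djl.pytorch</groupId>
			<artifactId>pytorch-jni</artifactId>
			<version>1.10.0-0.19.0</version>
		</dependency>
		<dependency>
			<groupId>org.elasticsearch.client</groupId>
			<artifactId>elasticsearch-rest-high-level-client</artifactId>
		</dependency>
	</dependencies>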

3、ES创建索引(定义mapping)

PUT /isi
{
  "mappings": {
    "properties": {
      "vector": {
        "type": "dense_vector",
        "dims": 1024
      },
      "url" : {
        "type" : "keyword"
      },
      "user_id": {
          "type": "keyword"
      }
    }
  }
}

4、编写java代码调用模型

ORCUtil.java

package com.topprismcloud.rtm;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.util.NDImageUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.Transform;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import org.apache.http.HttpHost;
import org.apache.http.auth.AuthScope;
import org.apache.http.auth.UsernamePasswordCredentials;
import org.apache.http.client.CredentialsProvider;
import org.apache.http.impl.client.BasicCredentialsProvider;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.RequestOptions;
import org.elasticsearch.client.RestClient;
import org.elasticsearch.client.RestClientBuilder;
import org.elasticsearch.client.RestHighLevelClient;
import org.elasticsearch.client.transport.TransportClient;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.ScriptQueryBuilder;
import org.elasticsearch.index.query.functionscore.FunctionScoreQueryBuilder;
import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilders;
import org.elasticsearch.script.Script;
import org.elasticsearch.script.ScriptType;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xcontent.XContentType;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.net.URL;
import java.nio.file.Paths;
import java.util.*;

/**
 * Image-similarity helper: converts images into feature vectors with a TorchScript model loaded
 * through DJL, and stores / searches those vectors in the Elasticsearch index {@code isi}.
 */
public class ORCUtil {

	private static final String INDEX = "isi";
	private static final int IMAGE_SIZE = 224;
	/** Maximum number of hits returned by {@link #search(InputStream)}. */
	private static final int TOP_K = 100;

	// NOTE(review): host, credentials and the image directory are hard-coded;
	// they should come from external configuration and must not be committed.
	private static final String ES_HOST = "14.20.30.16";
	private static final int ES_PORT = 9200;
	private static final String ES_USER = "elastic";
	private static final String ES_PASSWORD = "123456";
	private static final String IMAGE_DIR = "D:\\001ENV\\nginx-1.24.0\\html\\resource\\new";

	/** Serializes access to the shared predictor, which DJL does not guarantee to be thread-safe. */
	private static final Object PREDICT_LOCK = new Object();

	private static Model model; // the loaded model
	private static Predictor<Image, float[]> predictor; // predictor.predict(input) is like model(input) in Python

	static {
		// Fail class initialization loudly instead of leaving `predictor` null.
		try (InputStream modelStream = ORCUtil.class.getClassLoader().getResourceAsStream("model.pt")) {
			if (modelStream == null) {
				throw new IllegalStateException("model.pt not found on the classpath");
			}
			model = Model.newInstance("model");
			// model.pt is the traced TorchScript file produced by the Python script.
			model.load(modelStream);
			final Transform resize = new Resize(IMAGE_SIZE);
			final Transform toTensor = new ToTensor();
			final Transform normalize = new Normalize(new float[] { 0.485f, 0.456f, 0.406f },
					new float[] { 0.229f, 0.224f, 0.225f });
			// The translator turns the input Image into a tensor and the output into a float[].
			Translator<Image, float[]> translator = new Translator<Image, float[]>() {
				@Override
				public NDList processInput(TranslatorContext ctx, Image input) throws Exception {
					NDManager ndManager = ctx.getNDManager();
					System.out.println("input: " + input.getWidth() + ", " + input.getHeight());
					NDArray transform = normalize
							.transform(toTensor.transform(resize.transform(input.toNDArray(ndManager))));
					System.out.println(transform.getShape());
					NDList list = new NDList();
					list.add(transform);
					return list;
				}

				@Override
				public float[] processOutput(TranslatorContext ctx, NDList ndList) throws Exception {
					return ndList.get(0).toFloatArray();
				}
			};
			predictor = new Predictor<>(model, translator, Device.cpu(), true);
		} catch (Exception e) {
			throw new ExceptionInInitializerError(e);
		}
	}

	/** Builds an authenticated client so that upload and search connect the same way. */
	private static RestHighLevelClient createClient() {
		CredentialsProvider credentialsProvider = new BasicCredentialsProvider();
		credentialsProvider.setCredentials(AuthScope.ANY, new UsernamePasswordCredentials(ES_USER, ES_PASSWORD));
		RestClientBuilder builder = RestClient.builder(new HttpHost(ES_HOST, ES_PORT, HttpHost.DEFAULT_SCHEME_NAME));
		builder.setHttpClientConfigCallback(
				httpBuilder -> httpBuilder.setDefaultCredentialsProvider(credentialsProvider));
		return new RestHighLevelClient(builder);
	}

	/** Decodes an image from the stream and returns the model's feature vector for it. */
	private static float[] extractFeatures(InputStream imageStream) throws Exception {
		Image image = ImageFactory.getInstance().fromInputStream(imageStream);
		synchronized (PREDICT_LOCK) {
			return predictor.predict(image);
		}
	}

	/**
	 * Indexes every regular file of the image directory into Elasticsearch in one bulk request.
	 *
	 * @throws Exception if the directory cannot be listed, an image cannot be processed, or the
	 *                   bulk request fails (entirely or for individual items)
	 */
	public static void upload() throws Exception {
		// Skip sub-directories; listFiles returns null when the path is not a readable directory.
		File[] imageFiles = new File(IMAGE_DIR).listFiles(File::isFile);
		if (imageFiles == null) {
			throw new IOException("Image directory cannot be listed: " + IMAGE_DIR);
		}

		// Bulk upload request
		BulkRequest bulkRequest = new BulkRequest(INDEX);
		for (File imageFile : imageFiles) {
			float[] vector;
			try (InputStream in = new FileInputStream(imageFile)) {
				vector = extractFeatures(in);
			}
			// Build the document
			Map<String, Object> jsonMap = new HashMap<>();
			jsonMap.put("url", "/resource/" + imageFile.getName());
			jsonMap.put("vector", vector);
			jsonMap.put("user_id", "user123");
			bulkRequest.add(new IndexRequest(INDEX).source(jsonMap, XContentType.JSON));
		}
		if (bulkRequest.numberOfActions() == 0) {
			// An empty bulk request is rejected by request validation; nothing to do.
			return;
		}

		try (RestHighLevelClient client = createClient()) {
			org.elasticsearch.action.bulk.BulkResponse response = client.bulk(bulkRequest, RequestOptions.DEFAULT);
			if (response.hasFailures()) {
				throw new IOException("Bulk indexing reported failures: " + response.buildFailureMessage());
			}
		}
	}

	/**
	 * Searches the index for images similar to the given one.
	 *
	 * @param input stream of the query image; not closed by this method
	 * @return up to {@code TOP_K} results, each carrying the stored url and the hit score
	 * @throws Throwable if the image cannot be processed or the search request fails
	 */
	public static List<SearchResult> search(InputStream input) throws Throwable {
		Objects.requireNonNull(input, "input");
		float[] vector = extractFeatures(input);
		System.out.println(Arrays.toString(vector));

		// NOTE(review): cosineSimilarity can be negative, and Elasticsearch documents that scripts
		// must not produce negative scores — confirm whether "+ 1.0" is needed for this cluster.
		Script script = new Script(ScriptType.INLINE, "painless", "cosineSimilarity(params.queryVector, doc['vector'])",
				Collections.<String, Object>singletonMap("queryVector", vector));

		FunctionScoreQueryBuilder functionScoreQueryBuilder = QueryBuilders
				.functionScoreQuery(QueryBuilders.matchAllQuery(), ScoreFunctionBuilders.scriptFunction(script));

		SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
		// Exclude the vector field from the response: it is large and of no use to the caller.
		searchSourceBuilder.query(functionScoreQueryBuilder).fetchSource(null, "vector").size(TOP_K);

		SearchRequest searchRequest = new SearchRequest(INDEX);
		searchRequest.source(searchSourceBuilder);

		// try-with-resources closes the client even when the search throws.
		try (RestHighLevelClient client = createClient()) {
			SearchResponse searchResponse = client.search(searchRequest, RequestOptions.DEFAULT);
			SearchHits hits = searchResponse.getHits();

			List<SearchResult> results = new ArrayList<>();
			for (SearchHit hit : hits) {
				System.out.println(hit.toString());
				results.add(new SearchResult((String) hit.getSourceAsMap().get("url"), hit.getScore()));
			}
			return results;
		}
	}

	/** Command-line entry point: runs the bulk upload once. */
	public static void main(String[] args) throws Throwable {
		ORCUtil.upload();
		System.out.println("hao");
	}
}

SearchController.java

package com.topprismcloud.rtm;

import java.util.List;

import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.CrossOrigin;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.multipart.MultipartFile;

/**
 * REST endpoint that accepts an uploaded image and returns the similar images found by
 * {@link ORCUtil#search(java.io.InputStream)}.
 */
@RestController
@CrossOrigin
public class SearchController {

	private static final java.util.logging.Logger LOG =
			java.util.logging.Logger.getLogger(SearchController.class.getName());

	/**
	 * Handles {@code POST search}.
	 *
	 * @param file the uploaded query image
	 * @return 200 with the list of results, or 400 with an empty body when the upload is
	 *         missing/empty or the search fails
	 */
	@PostMapping("search")
	public ResponseEntity<List<?>> search(MultipartFile file) {
		// Reject a missing or empty upload explicitly instead of relying on a caught exception.
		if (file == null || file.isEmpty()) {
			return ResponseEntity.status(400).body(null);
		}
		// try-with-resources releases the upload stream once the search has consumed it.
		try (java.io.InputStream in = file.getInputStream()) {
			List<?> list = ORCUtil.search(in);
			return ResponseEntity.ok(list);
		} catch (Throwable e) {
			// ORCUtil.search declares Throwable, so that is what must be caught here.
			// The status stays 400 for compatibility, but the cause is logged, not dropped.
			LOG.log(java.util.logging.Level.WARNING, "Image search failed", e);
			return ResponseEntity.status(400).body(null);
		}
	}
}

SearchResult.java

package com.topprismcloud.rtm;

import lombok.AllArgsConstructor;
import lombok.Data;

/**
 * One image-search hit: the image url and its score.
 *
 * <p>Lombok's {@code @Data} generates the getters, setters, {@code equals}, {@code hashCode} and
 * {@code toString}; {@code @AllArgsConstructor} generates the {@code (url, score)} constructor.
 */
@Data
@AllArgsConstructor
public class SearchResult {
    /** The image url; {@code ORCUtil.search} fills it from the "url" field of the hit's source. */
    private String url;
    /** The hit score; {@code ORCUtil.search} fills it from {@code SearchHit.getScore()}. */
    private Float score;
}

5、前端

index.html





    
    以图搜图
    
    



    
请选择图片

请选择图片

6、打包后的源代码

以图搜图Java+html源代码

相关参考文章:Java调用Pytorch模型进行图像识别

你可能感兴趣的:(java,pytorch,开发语言)