<dependency>
    <groupId>io.milvus</groupId>
    <artifactId>milvus-sdk-java</artifactId>
    <version>2.2.3</version>
</dependency>
/**
 * Loads the question/answer pairs from the classpath resource {@code questions/lancode.csv}.
 *
 * <p>The file is decoded as GB2312 and parsed with its first record as the header; each row
 * must provide {@code question} and {@code answer} columns. Loading is best-effort: a missing
 * resource or any read/parse failure yields an empty list instead of an exception.
 *
 * @return the parsed pairs, or an empty list when the data cannot be loaded
 */
public List<MilvusParam> getData() {
    // try-with-resources closes both the resource stream and the parser on every path;
    // a null resource (file not on the classpath) is permitted and simply not closed.
    try (InputStream inputStream = ManagerAiController.class.getClassLoader()
            .getResourceAsStream("questions/lancode.csv")) {
        if (inputStream == null) {
            // Resource missing: report "no data" explicitly instead of relying on an NPE.
            return Lists.newArrayList();
        }
        String csvString = new String(inputStream.readAllBytes(), "gb2312");
        try (CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader()
                .parse(new StringReader(csvString))) {
            List<MilvusParam> dataList = Lists.newArrayList();
            for (CSVRecord record : parser) {
                MilvusParam data = new MilvusParam();
                data.setQuestion(record.get("question"));
                data.setAnswer(record.get("answer"));
                dataList.add(data);
            }
            return dataList;
        }
    } catch (Exception e) {
        // Deliberately best-effort: callers receive an empty list on any failure.
        // NOTE(review): consider logging `e` here once a logger is available in this class.
        return Lists.newArrayList();
    }
}
public List createVectors(List questions) throws IOException {
StringJoiner joiner = new StringJoiner(", ", "[", "]");
for (String question : questions) {
joiner.add("'" + question + "'");
}
String questionList = joiner.toString();
String prefix = "# -*- coding: utf-8 -*-\n";
String nodeSnippet = "from bert_serving.client import BertClient\n" +
"bc = BertClient(ip='localhost', check_version=False, port=5555, port_out=5556, check_length=False, timeout=10000)\n" +
"sentences_list = " + questionList + "\n" +
"output = bc.encode(sentences_list, show_tokens=False, is_tokenized=False, blocking=True)\n" +
"print(output.tolist())";
nodeSnippet = prefix + nodeSnippet;
String[] cmd = {"python", "-c", nodeSnippet};
ProcessBuilder pb = new ProcessBuilder(cmd);
pb.redirectInput(ProcessBuilder.Redirect.PIPE);
pb.redirectOutput(ProcessBuilder.Redirect.PIPE);
pb.redirectError(ProcessBuilder.Redirect.PIPE);
pb.environment().put("PYTHONIOENCODING", "UTF-8");
Process process = pb.start();
process.getOutputStream().close();
String result = result(process);
Gson gson = new Gson();
Type type = new com.google.gson.reflect.TypeToken>>() {
}.getType();
List> resultList = gson.fromJson(result, type);
List floatList = resultList.get(0);
log.info("size:" + floatList.size());
log.info("content:" + floatList);
error(process);
return floatList;
}
/**
 * Builds the question/answer store end to end.
 *
 * <p>It creates the collection schema, the partition, and the vector index. It then inserts the
 * data and finally loads the partition into memory so it can be searched.
 *
 * @param milvusList question/answer pairs to store
 * @throws IOException propagated from {@code insert} (vector encoding)
 */
@Override
public void createCollection(List milvusList) throws IOException {
    final String collection = MilvusConstant.COLLECTION_NAME;
    // Schema first: every later step refers to the collection by name.
    creatCollection();
    createPartition(collection, MilvusConstant.PARTITION_NAME);
    // Loading data into memory requires an index on the vector field.
    createIndex(collection, MilvusConstant.Field.VECTOR);
    insert(milvusList);
    // Make the stored rows available to searches.
    loadCollection();
}
/**
 * Creates the Milvus collection {@code COLLECTION_NAME} with two shards and four fields.
 *
 * <p>The fields are an auto-generated Int64 primary key, a float vector of {@code FEATURE_DIM}
 * dimensions, and VarChar question and answer columns. The RPC status is only logged, not
 * checked.
 */
public void creatCollection() {
    FieldType primaryKey = FieldType.newBuilder()
            .withName(MilvusConstant.Field.ID)
            .withDataType(DataType.Int64)
            .withPrimaryKey(true)
            .withAutoID(true)
            .build();
    FieldType embedding = FieldType.newBuilder()
            .withName(MilvusConstant.Field.VECTOR)
            .withDataType(DataType.FloatVector)
            .withDimension(MilvusConstant.FEATURE_DIM)
            .build();
    FieldType questionText = FieldType.newBuilder()
            .withName(MilvusConstant.Field.QUESTION)
            .withDataType(DataType.VarChar)
            .withMaxLength(MilvusConstant.QUESTION_LENGTH)
            .build();
    FieldType answerText = FieldType.newBuilder()
            .withName(MilvusConstant.Field.ANSWER)
            .withDataType(DataType.VarChar)
            .withMaxLength(MilvusConstant.ANSWER_LENGTH)
            .build();

    CreateCollectionParam schema = CreateCollectionParam.newBuilder()
            .withCollectionName(MilvusConstant.COLLECTION_NAME)
            .withDescription("question answer search")
            .withShardsNum(2)
            .addFieldType(primaryKey)
            .addFieldType(embedding)
            .addFieldType(questionText)
            .addFieldType(answerText)
            .build();
    R outcome = milvusServiceClient.createCollection(schema);
    log.info(MilvusConstant.COLLECTION_NAME + "是否成功创建集合——>>" + outcome.getStatus());
}
/**
 * Creates a named partition inside a collection.
 *
 * @param collectionName collection that will own the partition
 * @param partitionName  name of the partition to create
 * @return the status code of the Milvus RPC response
 */
public Integer createPartition(String collectionName, String partitionName) {
    CreatePartitionParam request = CreatePartitionParam.newBuilder()
            .withCollectionName(collectionName)
            .withPartitionName(partitionName)
            .build();
    return milvusServiceClient.createPartition(request).getStatus();
}
/**
 * Requests an IVF_FLAT index on {@code fieldName}, then logs the index description status.
 *
 * <p>The index uses the inner-product (IP) metric with {@code nlist=16384}, is named after the
 * field, and is requested with sync mode disabled. The describeIndex status is logged afterwards.
 *
 * @param collectionName collection that holds the field
 * @param fieldName      vector field to index; also used as the index name
 * @return the raw createIndex response
 */
public R createIndex(String collectionName, String fieldName) {
    CreateIndexParam indexRequest = CreateIndexParam.newBuilder()
            .withCollectionName(collectionName)
            .withFieldName(fieldName)
            .withIndexName(fieldName)
            .withIndexType(IndexType.IVF_FLAT)
            .withMetricType(MetricType.IP)
            .withExtraParam("{\"nlist\":16384}")
            .withSyncMode(Boolean.FALSE)
            .build();
    R response = milvusServiceClient.createIndex(indexRequest);
    log.info("createIndex-------------------->{}", response.toString());

    DescribeIndexParam describeRequest = DescribeIndexParam.newBuilder()
            .withCollectionName(collectionName)
            .withIndexName(fieldName)
            .build();
    R description = milvusServiceClient.describeIndex(describeRequest);
    log.info("after createIndex-------------------->{}", description.getStatus());
    return response;
}
/**
 * Loads the {@code PARTITION_NAME} partition of {@code COLLECTION_NAME} into memory and logs
 * the RPC status.
 *
 * <p>Despite the method name, only that one partition is loaded, not the whole collection.
 */
public void loadCollection() {
    LoadPartitionsParam loadRequest = LoadPartitionsParam.newBuilder()
            .withCollectionName(MilvusConstant.COLLECTION_NAME)
            .withPartitionNames(Lists.newArrayList(MilvusConstant.PARTITION_NAME))
            .build();
    R loadResult = milvusServiceClient.loadPartitions(loadRequest);
    log.info("Milvus数据加载到内存状态:{}", loadResult.getStatus());
}
/**
 * Encodes each question into a vector and inserts the rows into Milvus.
 *
 * <p>Each row holds the vector, the question and the answer. Rows go into the
 * {@code PARTITION_NAME} partition of {@code COLLECTION_NAME}. That is the partition created by
 * {@code createPartition} and loaded by {@code loadCollection}; without it, rows land in the
 * default partition, which is never loaded, so they could not be searched.
 *
 * @param milvusParamList question/answer pairs; a null or empty list is logged and skipped
 * @throws IOException if encoding a question into a vector fails
 */
public void insert(List<MilvusParam> milvusParamList) throws IOException {
    if (milvusParamList == null || milvusParamList.isEmpty()) {
        // Milvus rejects zero-row inserts; nothing to store.
        log.warn("{} has no data to insert, skipping", MilvusConstant.COLLECTION_NAME);
        return;
    }
    int rowCount = milvusParamList.size();
    List<String> questions = new ArrayList<>(rowCount);
    List<String> answers = new ArrayList<>(rowCount);
    List<List<Float>> vectors = new ArrayList<>(rowCount);
    for (MilvusParam vo : milvusParamList) {
        // One encoder call per question: createVectors returns only the first sentence's vector.
        vectors.add(milvusVectorsHandler.createVectors(Arrays.asList(vo.getQuestion())));
        questions.add(vo.getQuestion());
        answers.add(vo.getAnswer());
    }
    // Column-oriented insert: the three lists are index-aligned row by row.
    List<InsertParam.Field> fields = new ArrayList<>();
    fields.add(new InsertParam.Field(MilvusConstant.Field.VECTOR, vectors));
    fields.add(new InsertParam.Field(MilvusConstant.Field.QUESTION, questions));
    fields.add(new InsertParam.Field(MilvusConstant.Field.ANSWER, answers));
    InsertParam insertParam = InsertParam.newBuilder()
            .withCollectionName(MilvusConstant.COLLECTION_NAME)
            .withPartitionName(MilvusConstant.PARTITION_NAME)
            .withFields(fields)
            .build();
    R insert = milvusServiceClient.insert(insertParam);
    log.info(MilvusConstant.COLLECTION_NAME + "是否成功插入数据——>>" + insert.getStatus());
}
/**
 * Answers a question with a vector similarity search.
 *
 * <p>The question is encoded into a single query vector, which is searched against the vector
 * field of the {@code PARTITION_NAME} partition.
 *
 * @param question the user's question
 * @return the value of {@code getData()} on the similarity search result
 * @throws IOException if encoding the question into a vector fails
 */
@Override
public String getAnswer(String question) throws IOException {
    List<Float> vector = milvusVectorsHandler.createVectors(Arrays.asList(question));
    // The search API takes a batch of query vectors; this is a batch of one.
    List<List<Float>> vectors = Arrays.asList(vector);
    String collectionName = MilvusConstant.COLLECTION_NAME;
    String partitionName = MilvusConstant.PARTITION_NAME;
    String fieldName = MilvusConstant.Field.VECTOR;
    // NOTE(review): the trailing null is presumably "no filter expression"; confirm against searchSimilarity.
    SearchSimilarity searchSimilarity = searchSimilarity(vectors, collectionName, partitionName, fieldName, null);
    return searchSimilarity.getData();
}
此处使用 Milvus 自带的内积（IP）度量来计算相似度；实际应用中应使用余弦相似度。由于余弦相似度等于两个向量经 L2 归一化后的内积，可以先对向量做 L2 归一化，再使用 IP 度量，从而得到与余弦相似度等价的结果。