基于Milvus和BERT搭建AI智能问答系统(Java+Python实现)

添加Maven依赖


<dependency>
	<groupId>io.milvus</groupId>
	<artifactId>milvus-sdk-java</artifactId>
	<version>2.2.3</version>
</dependency>

读取CSV原始数据集(dataset):文件位于classpath下的questions/lancode.csv,采用GB2312编码,首行为表头,包含question和answer两列

    /**
     * Loads the question/answer dataset from the classpath resource {@code questions/lancode.csv}.
     *
     * <p>The file is decoded as GB2312 and parsed with its first row as the header; the
     * {@code question} and {@code answer} columns of each following row become one
     * {@link MilvusParam}.
     *
     * @return the parsed rows, or an empty list when the resource is missing or cannot be read or
     *     parsed (the failure is logged instead of being propagated)
     */
    public List<MilvusParam> getData() {
        String resource = "questions/lancode.csv";
        // try-with-resources closes the classpath stream, which the previous version leaked.
        // A null resource is allowed here and is simply not closed.
        try (InputStream inputStream = ManagerAiController.class.getClassLoader().getResourceAsStream(resource)) {
            if (inputStream == null) {
                log.error("CSV dataset not found on classpath: {}", resource);
                return Lists.newArrayList();
            }
            // Decode directly from the stream; malformed bytes are replaced, just as new String(bytes, cs) did.
            Reader reader = new java.io.InputStreamReader(inputStream, java.nio.charset.Charset.forName("gb2312"));
            try (CSVParser parser = CSVFormat.DEFAULT.withFirstRecordAsHeader().parse(reader)) {
                List<MilvusParam> dataList = Lists.newArrayList();
                for (CSVRecord record : parser) {
                    MilvusParam data = new MilvusParam();
                    data.setQuestion(record.get("question"));
                    data.setAnswer(record.get("answer"));
                    dataList.add(data);
                }
                return dataList;
            }
        } catch (java.io.IOException | RuntimeException e) {
            // Best-effort loader: keep returning an empty list, but never fail silently.
            log.error("Failed to load CSV dataset {}", resource, e);
            return Lists.newArrayList();
        }
    }

使用BERT生成特征向量(Java通过Python调用bert-serving-client,需先在本机启动bert-serving-server,并监听5555/5556端口)

    public List createVectors(List questions) throws IOException {
        StringJoiner joiner = new StringJoiner(", ", "[", "]");
        for (String question : questions) {
            joiner.add("'" + question + "'");
        }
        String questionList = joiner.toString();

        String prefix = "# -*- coding: utf-8 -*-\n";
        String nodeSnippet = "from bert_serving.client import BertClient\n" +
                "bc = BertClient(ip='localhost', check_version=False, port=5555, port_out=5556, check_length=False, timeout=10000)\n" +
                "sentences_list = " + questionList + "\n" +
                "output = bc.encode(sentences_list, show_tokens=False, is_tokenized=False, blocking=True)\n" +
                "print(output.tolist())";
        nodeSnippet = prefix + nodeSnippet;

        String[] cmd = {"python", "-c", nodeSnippet};
        ProcessBuilder pb = new ProcessBuilder(cmd);
        pb.redirectInput(ProcessBuilder.Redirect.PIPE);
        pb.redirectOutput(ProcessBuilder.Redirect.PIPE);
        pb.redirectError(ProcessBuilder.Redirect.PIPE);
        pb.environment().put("PYTHONIOENCODING", "UTF-8");

        Process process = pb.start();
        process.getOutputStream().close();

        String result = result(process);

        Gson gson = new Gson();
        Type type = new com.google.gson.reflect.TypeToken>>() {
        }.getType();
        List> resultList = gson.fromJson(result, type);
        List floatList = resultList.get(0);
        log.info("size:" + floatList.size());
        log.info("content:" + floatList);

        error(process);
        return floatList;
    }

将dataset加载到向量数据库Milvus

    /**
     * Builds the question/answer collection end to end and makes it searchable.
     *
     * <p>Steps, in order: create the collection, create its partition, create the vector index,
     * insert {@code milvusList}, then load the partition into memory. Status codes returned by the
     * individual steps are not inspected here.
     *
     * @param milvusList rows to insert
     * @throws IOException propagated from {@code insert}
     */
    @Override
    public void createCollection(List milvusList) throws IOException {
        final String collection = MilvusConstant.COLLECTION_NAME;

        creatCollection();
        createPartition(collection, MilvusConstant.PARTITION_NAME);
        // The vector index has to exist before the data can be loaded into memory.
        createIndex(collection, MilvusConstant.Field.VECTOR);
        insert(milvusList);
        loadCollection();
    }

    /**
     * Creates collection {@code MilvusConstant.COLLECTION_NAME} with two shards and four fields, in
     * this order:
     * <ol>
     *   <li>an Int64 primary key with auto-generated IDs,</li>
     *   <li>a float vector of {@code MilvusConstant.FEATURE_DIM} dimensions,</li>
     *   <li>a VarChar question column,</li>
     *   <li>a VarChar answer column.</li>
     * </ol>
     * Logs the status returned by Milvus.
     */
    public void creatCollection() {
        FieldType primaryKey = FieldType.newBuilder()
                .withName(MilvusConstant.Field.ID)
                .withDataType(DataType.Int64)
                .withPrimaryKey(true)
                .withAutoID(true)
                .build();
        FieldType embedding = FieldType.newBuilder()
                .withName(MilvusConstant.Field.VECTOR)
                .withDataType(DataType.FloatVector)
                .withDimension(MilvusConstant.FEATURE_DIM)
                .build();
        FieldType questionText = FieldType.newBuilder()
                .withName(MilvusConstant.Field.QUESTION)
                .withDataType(DataType.VarChar)
                .withMaxLength(MilvusConstant.QUESTION_LENGTH)
                .build();
        FieldType answerText = FieldType.newBuilder()
                .withName(MilvusConstant.Field.ANSWER)
                .withDataType(DataType.VarChar)
                .withMaxLength(MilvusConstant.ANSWER_LENGTH)
                .build();

        CreateCollectionParam request = CreateCollectionParam.newBuilder()
                .withCollectionName(MilvusConstant.COLLECTION_NAME)
                .withDescription("question answer search")
                .withShardsNum(2)
                .addFieldType(primaryKey)
                .addFieldType(embedding)
                .addFieldType(questionText)
                .addFieldType(answerText)
                .build();

        R outcome = milvusServiceClient.createCollection(request);
        log.info(MilvusConstant.COLLECTION_NAME + "是否成功创建集合——>>" + outcome.getStatus());
    }

    /**
     * Creates partition {@code partitionName} inside collection {@code collectionName}.
     *
     * @param collectionName target collection
     * @param partitionName  partition to create
     * @return the status code carried by the Milvus response
     */
    public Integer createPartition(String collectionName, String partitionName) {
        CreatePartitionParam request = CreatePartitionParam.newBuilder()
                .withPartitionName(partitionName)
                .withCollectionName(collectionName)
                .build();
        R outcome = milvusServiceClient.createPartition(request);
        return outcome.getStatus();
    }

    /**
     * Requests an IVF_FLAT index with the inner-product (IP) metric and {@code nlist = 16384} on
     * {@code fieldName}, using the field name as the index name too.
     *
     * <p>The request is sent with sync mode off. The index is then described straight away, and
     * only the status of that describe call is logged, so the log presumably does not confirm
     * that the build has finished.
     *
     * @param collectionName collection that owns the field
     * @param fieldName      vector field to index (also used as the index name)
     * @return the raw response of the create-index call
     */
    public R createIndex(String collectionName, String fieldName) {
        CreateIndexParam createRequest = CreateIndexParam.newBuilder()
                .withCollectionName(collectionName)
                .withFieldName(fieldName)
                .withIndexName(fieldName)
                .withIndexType(IndexType.IVF_FLAT)
                .withMetricType(MetricType.IP)
                .withExtraParam("{\"nlist\":16384}")
                .withSyncMode(Boolean.FALSE)
                .build();
        R created = milvusServiceClient.createIndex(createRequest);
        log.info("createIndex-------------------->{}", created.toString());

        DescribeIndexParam describeRequest = DescribeIndexParam.newBuilder()
                .withCollectionName(collectionName)
                .withIndexName(fieldName)
                .build();
        R described = milvusServiceClient.describeIndex(describeRequest);
        log.info("after createIndex-------------------->{}", described.getStatus());

        return created;
    }

    /**
     * Loads partition {@code MilvusConstant.PARTITION_NAME} of collection
     * {@code MilvusConstant.COLLECTION_NAME} into memory and logs the returned status.
     * Despite the method name, only that partition is loaded.
     */
    public void loadCollection() {
        LoadPartitionsParam request = LoadPartitionsParam.newBuilder()
                .withCollectionName(MilvusConstant.COLLECTION_NAME)
                .withPartitionNames(Lists.newArrayList(MilvusConstant.PARTITION_NAME))
                .build();
        R outcome = milvusServiceClient.loadPartitions(request);
        log.info("Milvus数据加载到内存状态:{}", outcome.getStatus());
    }

    /**
     * Encodes every question with BERT and inserts the rows into partition
     * {@code MilvusConstant.PARTITION_NAME} of collection {@code MilvusConstant.COLLECTION_NAME}.
     *
     * <p>The partition is set explicitly. {@code createCollection} creates this partition,
     * {@code loadCollection} loads only this partition, and {@code getAnswer} passes it to the
     * search. Without {@code withPartitionName}, Milvus stores the rows in the collection's
     * {@code _default} partition, which is never loaded or searched. The ID column is omitted
     * because the primary key is declared with auto-ID.
     *
     * @param milvusParamList rows to insert; nothing is sent to Milvus when it is null or empty
     * @throws IOException if encoding a question fails
     */
    public void insert(List<MilvusParam> milvusParamList) throws IOException {
        if (milvusParamList == null || milvusParamList.isEmpty()) {
            log.warn(MilvusConstant.COLLECTION_NAME + " insert skipped: no data");
            return;
        }

        int size = milvusParamList.size();
        List<String> questions = new ArrayList<>(size);
        List<String> answers = new ArrayList<>(size);
        List<List<Float>> vectors = new ArrayList<>(size);
        for (MilvusParam vo : milvusParamList) {
            // One encoder round trip per question: createVectors returns only the first vector.
            vectors.add(milvusVectorsHandler.createVectors(Arrays.asList(vo.getQuestion())));
            questions.add(vo.getQuestion());
            answers.add(vo.getAnswer());
        }

        List<InsertParam.Field> fields = new ArrayList<>();
        fields.add(new InsertParam.Field(MilvusConstant.Field.VECTOR, vectors));
        fields.add(new InsertParam.Field(MilvusConstant.Field.QUESTION, questions));
        fields.add(new InsertParam.Field(MilvusConstant.Field.ANSWER, answers));

        InsertParam insertParam = InsertParam.newBuilder()
                .withCollectionName(MilvusConstant.COLLECTION_NAME)
                .withPartitionName(MilvusConstant.PARTITION_NAME)
                .withFields(fields)
                .build();
        R insert = milvusServiceClient.insert(insertParam);

        log.info(MilvusConstant.COLLECTION_NAME + "是否成功插入数据——>>" + insert.getStatus());
    }

根据用户问题从Milvus中检索答案

    /**
     * Answers {@code question} by encoding it with BERT and running a similarity search on the
     * vector field of partition {@code MilvusConstant.PARTITION_NAME} in collection
     * {@code MilvusConstant.COLLECTION_NAME}.
     *
     * @param question the user's question
     * @return the value of {@code getData()} on the search result
     * @throws IOException propagated from encoding the question
     */
    @Override
    public String getAnswer(String question) throws IOException {
        List<Float> queryVector = milvusVectorsHandler.createVectors(Arrays.asList(question));
        SearchSimilarity hit = searchSimilarity(
                Arrays.asList(queryVector),
                MilvusConstant.COLLECTION_NAME,
                MilvusConstant.PARTITION_NAME,
                MilvusConstant.Field.VECTOR,
                null);
        return hit.getData();
    }

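此处使用Milvus自带的IP(内积,Inner Product)度量来计算相似度。代码中没有对BERT输出的句向量做归一化,因此内积会受向量长度影响。实际应用中应使用余弦相似度。Milvus 2.3起才原生支持COSINE度量;在本文使用的2.2.x版本中,可以在插入和查询前先对向量做L2归一化,再使用IP度量,此时内积就等于余弦相似度。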

你可能感兴趣的:(人工智能,milvus,bert,python,人工智能,自然语言处理)