随着国内外大模型热度的兴起,依托于大模型的智能化,传统的人机交互已经不能满足人们的交互需求。而结合语音和大模型的交互,摆脱了传统互联网获取知识时的文字限制,用语音也可以轻松获取想要的知识和思路。
一、大模型智能语音交互调用实现思路
唤醒持续运行--->合成能力加持(唤醒成功后语音答复,例如“主人,您请说~”)--->调用在线或离线听写能力(建议使用讯飞在线听写,效果更好)--->将用户说的语音识别成文字后发给大模型--->建议调用讯飞星火认知大模型--->获取大模型答案后调用语音合成(在线、离线合成均可)进行答案输出。
这样就顺利实现了用纯语音与大模型进行交互!
难点:唤醒+听写同时读取麦克风音频的节奏控制
持续语音交互调用大模型效果图:
二、离线环境常量定义
package com.day.config;
import com.sun.jna.ptr.IntByReference;
import javax.sound.sampled.*;
import java.io.ByteArrayOutputStream;
import java.io.FileInputStream;
/**
 * Shared configuration and runtime state for the offline wake-up (IVW) engine, the microphone
 * capture line, and the offline text-to-speech (TTS) engine.
 *
 * NOTE(review): many fields here are public, non-final statics that are read and written from
 * several threads without synchronization; treat them as global variables, not constants.
 */
public class Constants {
    // Application id; embedded in both login parameter strings below.
    public static final String APPID = "5e11538f"; // APPID
    // MSC working directory passed as work_dir at login.
    public static final String WORK_DIR = "src/main/resources";
    // 1. Wake-up (IVW) settings.
    // Per the original author: the fo|xxx resource path inside IVW_SSB_PARAMS is resolved under
    // WORK_DIR/msc/, i.e. xxx is appended after that directory — make sure that path is right.
    // Capture format: 16 kHz, 16-bit, mono, signed, little-endian PCM. Used for the wake-up loop
    // and for the dictation client, which share the same capture line.
    public static final AudioFormat IVW_ASR_AUDIO_FORMAT = new AudioFormat(16000F, 16, 1, true, false);
    public static final String IVW_DLL_PATH = "src/main/resources/ivw_msc_x64.dll"; // path to the Windows wake-up DLL
    public static final String IVW_LOGIN_PARAMS = "appid = " + APPID + ", work_dir = " + WORK_DIR;
    // Parameters for QIVWSessionBegin: threshold, session type (wakeup) and wake-word resource file.
    // See the MSC documentation for the exact meaning of each key.
    public static final String IVW_SSB_PARAMS = "ivw_threshold=0:1500,sst=wakeup,ivw_shot_word=1,ivw_res_path =fo|res/ivw/wakeupresource.jet";
    // Out-parameter receiving the error code of QIVWSessionBegin; -100 is only an initial sentinel.
    public static IntByReference IVW_ERROR_CODE = new IntByReference(-100);
    // Bytes fed to the engine per write: 6400 B = 200 ms of 16 kHz 16-bit mono audio.
    // Original note: write 10 frames of 640 B every 200 ms, otherwise wake-up may stop working after
    // a while; other frame sizes may prevent wake-up entirely.
    public static Integer IVW_FRAME_SIZE = 6400;
    // Audio status passed to QIVWAudioWrite. Starts at 1; AIMain switches it to 2 after a write
    // and to 4 for the end-of-stream write.
    public static Integer IVW_AUDIO_STATUS = 1;
    public static DataLine.Info IVW_ASR_DATA_LINE_INFO = new DataLine.Info(TargetDataLine.class, IVW_ASR_AUDIO_FORMAT);
    // Microphone capture line shared by the wake-up loop and the dictation client.
    public static TargetDataLine IVW_ASR_TARGET_DATA_LINE;
    // Acquire (but do not open) the capture line. On failure the stack trace is printed and the
    // field stays null. This must run after IVW_ASR_DATA_LINE_INFO is initialized.
    static {
        try {
            IVW_ASR_TARGET_DATA_LINE = (TargetDataLine) AudioSystem.getLine(IVW_ASR_DATA_LINE_INFO);
        } catch (LineUnavailableException e) {
            e.printStackTrace();
        }
    }
    // 2. Offline text-to-speech (TTS) settings.
    // Playback format: 16 kHz, 16-bit, mono, signed, little-endian PCM (matches sample_rate = 16000 below).
    public static final AudioFormat TTS_AUDIO_FORMAT = new AudioFormat(16000F, 16, 1, true, false);
    public static final String TTS_DLL_PATH = "src/main/resources/tts_msc_x64.dll"; // path to the Windows TTS DLL
    public static final String TTS_LOGIN_PARAMS = "appid = " + APPID + ", work_dir = " + WORK_DIR;
    // Local engine, voice "xiaoyuan", UTF-8 text, voice resources under the work dir, 16 kHz output.
    public static final String TTS_SESSION_BEGIN_PARAMS = "engine_type = local, voice_name = xiaoyuan, text_encoding = UTF8," + " tts_res_path = fo|res/tts/xiaoyuan.jet;fo|res/tts/common.jet, sample_rate = 16000, speed = 50, volume = 50, pitch = 50, rdn = 2";
    // Out-parameters for the QTTS* calls; -100 is only an initial sentinel.
    public static IntByReference TTS_ERROR_CODE = new IntByReference(-100);
    public static IntByReference TTS_AUDIO_LEN = new IntByReference(-100);
    public static IntByReference TTS_SYNTH_STATUS = new IntByReference(-100);
    public static String TTS_TEXT; // text currently being synthesized
    public static Integer TTS_TOTAL_AUDIO_LENGTH; // running total of synthesized audio bytes
    public static ByteArrayOutputStream TTS_BYTE_ARRAY_OUTPUT_STREAM; // synthesized audio buffer (not referenced in the code shown here)
    // Speaker output line description; buffer size is left unspecified.
    public static DataLine.Info TTS_DATA_LINE_INFO = new DataLine.Info(SourceDataLine.class, TTS_AUDIO_FORMAT, AudioSystem.NOT_SPECIFIED);
    public static SourceDataLine TTS_SOURCE_DATA_LINE; // playback line
    // Acquire (but do not open) the playback line. On failure the stack trace is printed and the
    // field stays null.
    static {
        try {
            TTS_SOURCE_DATA_LINE = (SourceDataLine) AudioSystem.getLine(Constants.TTS_DATA_LINE_INFO);
        } catch (LineUnavailableException e) {
            e.printStackTrace();
        }
    }
    public static final String YELLOW_BACKGROUND = "\u001B[43m"; // ANSI code for yellow background
    public static final String RESET = "\u001B[0m"; // ANSI code to reset to default
}
三、唤醒+合成代码
package com.day;
import com.day.config.Constants;
import com.day.service.IvwService;
import com.day.service.TtsService;
import com.day.service.imp.IvwCallback;
import com.sun.jna.Pointer;
import javax.sound.sampled.*;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
// 主函数入口
public class AIMain {
public static boolean ttsFlag = false;
public static boolean ivwFlag = false;
public static byte[] audioDataByteArray;
public static int len;
public static void main(String[] args) throws Exception {
// 调用流程:唤醒--->
// System.out.println(Constants.yellowBackground + "呼叫大飞" + Constants.reset);
// 以线程的方式启动唤醒
MyThread myThread = new MyThread();
myThread.start();
}
static class MyThread extends Thread {
public void run() {
startIvw();
}
}
// 1、唤醒调用
public static void startIvw() {
Integer ret = IvwService.INSTANCE.MSPLogin(null, null, Constants.IVW_LOGIN_PARAMS); // 登录
if (ret != 0) {
System.out.println("唤醒登录失败...:" + ret);
}
String sessionId = IvwService.INSTANCE.QIVWSessionBegin(null, Constants.IVW_SSB_PARAMS, Constants.IVW_ERROR_CODE); // 开启会话
if (Constants.IVW_ERROR_CODE.getValue() != 0) {
System.out.println("开启唤醒会话失败...:" + Constants.IVW_ERROR_CODE.getValue());
}
ret = IvwService.INSTANCE.QIVWRegisterNotify(sessionId, new IvwCallback(), null); // 注册唤醒回调函数
if (ret != 0) {
System.out.println("注册唤醒回调函数失败...:" + ret);
}
try {
while (true) {
// System.err.println("唤醒监听中");
Constants.IVW_ASR_TARGET_DATA_LINE.open(Constants.IVW_ASR_AUDIO_FORMAT);
Constants.IVW_ASR_TARGET_DATA_LINE.start();
audioDataByteArray = new byte[Constants.IVW_FRAME_SIZE];
len = new AudioInputStream(Constants.IVW_ASR_TARGET_DATA_LINE).read(audioDataByteArray);
if (len == -1) { // 调用麦克风时候,这段将不会被执行...
Constants.IVW_AUDIO_STATUS = 4;
ret = IvwService.INSTANCE.QIVWAudioWrite(sessionId, "".getBytes(), 0, Constants.IVW_AUDIO_STATUS);
System.out.println("最后一帧返回的错误码:" + ret + ",即将执行退出...");
break; //文件读完,跳出循环
} else {
// 反复调用QIVWAudioWrite写音频方法,直到音频写完为止!!!!!!!!!!!!
ret = IvwService.INSTANCE.QIVWAudioWrite(sessionId, audioDataByteArray, len, Constants.IVW_AUDIO_STATUS);
// System.out.println("写入音频中");
}
Constants.IVW_AUDIO_STATUS = 2; // 中间帧
if (ret != 0) {
System.err.println("唤醒音频写入失败...:" + ret);
}
Thread.sleep(200); // 模拟人说话时间间隙,10帧的音频200ms写入一次
if (ivwFlag) {
IvwService.INSTANCE.QIVWSessionEnd(sessionId, "");
IvwService.INSTANCE.MSPLogout();
Constants.IVW_ASR_TARGET_DATA_LINE.stop();
Constants.IVW_ASR_TARGET_DATA_LINE.close();
ivwFlag = false;
break;
}
// System.err.println("唤醒监听中");
}
startIvw();
} catch (Exception e) {
e.printStackTrace();
}
}
// 2、合成调用
public static void startTts(String ttsText) {
if (!AIMain.ttsFlag) {
ttsFlag = true;
Constants.TTS_TEXT = ttsText;
Constants.TTS_TOTAL_AUDIO_LENGTH = 0;
Integer ret = TtsService.INSTANCE.MSPLogin(null, null, Constants.TTS_LOGIN_PARAMS); // 登录
if (ret != 0) {
System.out.println("合成登录失败...:" + ret);
}
String session_id = TtsService.INSTANCE.QTTSSessionBegin(Constants.TTS_SESSION_BEGIN_PARAMS, Constants.TTS_ERROR_CODE); // 开启合成会话
if (Constants.TTS_ERROR_CODE.getValue() != 0) {
System.out.println("合成开启会话失败...:" + Constants.TTS_ERROR_CODE.getValue());
}
ret = TtsService.INSTANCE.QTTSTextPut(session_id, Constants.TTS_TEXT, Constants.TTS_TEXT.getBytes().length, null); // 正式合成
if (ret != 0) {
System.out.println("合成音频失败...:" + ret);
}
try { //实时播放
Constants.TTS_SOURCE_DATA_LINE.open(Constants.TTS_AUDIO_FORMAT);
Constants.TTS_SOURCE_DATA_LINE.start();
} catch (Exception e) {
e.printStackTrace();
}
while (true) {
Pointer audioPointer = TtsService.INSTANCE.QTTSAudioGet(session_id, Constants.TTS_AUDIO_LEN, Constants.TTS_SYNTH_STATUS, Constants.TTS_ERROR_CODE); // 获取音频
byte[] audioDataByteArray = null;
if (audioPointer != null) {
audioDataByteArray = audioPointer.getByteArray(0, Constants.TTS_AUDIO_LEN.getValue());
}
if (Constants.TTS_ERROR_CODE.getValue() != 0) {
System.out.println("合成获取音频失败...+:" + Constants.TTS_ERROR_CODE);
break;
}
if (audioDataByteArray != null) {
try {
Constants.TTS_SOURCE_DATA_LINE.write(audioDataByteArray, 0, Constants.TTS_AUDIO_LEN.getValue()); //实时写音频流
} catch (Exception e) {
e.printStackTrace();
}
Constants.TTS_TOTAL_AUDIO_LENGTH = Constants.TTS_TOTAL_AUDIO_LENGTH + Constants.TTS_AUDIO_LEN.getValue(); //计算总音频长度,用来生成音频文件
}
if (Constants.TTS_SYNTH_STATUS.getValue() == 2) {
// 说明音频已经取完,退出本次循环
try {
// Constants.TTS_SOURCE_DATA_LINE.drain();
// Constants.TTS_SOURCE_DATA_LINE.close();
} catch (Exception e) {
e.printStackTrace();
}
break;
}
}
ret = TtsService.INSTANCE.QTTSSessionEnd(session_id, "正常退出"); //结束会话
if (ret != 0) {
System.out.println("合成结束会话失败...:" + ret);
}
ret = TtsService.INSTANCE.MSPLogout(); // 退出
if (ret != 0) {
System.out.println("合成退出失败...:" + ret);
}
} else {
Constants.TTS_SOURCE_DATA_LINE.stop();
Constants.TTS_SOURCE_DATA_LINE.close();
}
AIMain.ttsFlag = false;
}
}
唤醒+合成库加载
package com.day.service;
import com.day.config.Constants;
import com.day.service.imp.IvwCallback;
import com.sun.jna.Library;
import com.sun.jna.Native;
import com.sun.jna.ptr.IntByReference;
public interface IvwService extends Library {
    /*
     * JNA type mapping used by the declarations below:
     *   char *             -> String
     *   int *              -> IntByReference
     *   void *             -> Pointer or byte[]
     *   int                -> int
     *   (no arguments)     -> no parameters
     *   callback pointer   -> a class implementing com.sun.jna.Callback, written per the SDK docs
     */

    /** Native binding to the wake-up library, loaded once from Constants.IVW_DLL_PATH. */
    IvwService INSTANCE = Native.loadLibrary(Constants.IVW_DLL_PATH, IvwService.class);

    /** MSPLogin(const char *usr, const char *pwd, const char *params). */
    Integer MSPLogin(String usr, String pwd, String params);

    /** QIVWSessionBegin(const char *grammarList, const char *params, int *errorCode); returns the session id. */
    String QIVWSessionBegin(String grammarList, String params, IntByReference errorCode);

    /** QIVWAudioWrite(const char *sessionID, const void *audioData, unsigned int audioLen, int audioStatus). */
    Integer QIVWAudioWrite(String sessionID, byte[] audioData, int audioLen, int audioStatus);

    /** QIVWSessionEnd(const char *sessionID, const char *hints). */
    Integer QIVWSessionEnd(String sessionID, String hints);

    /** QIVWRegisterNotify(const char *sessionID, ivw_ntf_handler msgProcCb, void *userData): installs the message callback. */
    Integer QIVWRegisterNotify(String sessionID, IvwCallback ivwCallback, byte[] userData);

    /** MSPLogout(); the original note says wake-up usually does not need to log out. */
    Integer MSPLogout();
}
package com.day.service;
import com.day.config.Constants;
import com.sun.jna.Library;
import com.sun.jna.Native;
import com.sun.jna.Pointer;
import com.sun.jna.ptr.IntByReference;
public interface TtsService extends Library {
    /*
     * JNA type mapping used by the declarations below:
     *   char *             -> String
     *   int *              -> IntByReference
     *   void *             -> byte[] or Pointer (inside callbacks the original note maps it to String)
     *   int                -> int
     *   (no arguments)     -> no parameters
     *   callbacks          -> Callback implementations; offline synthesis uses none
     */

    /** Native binding to the TTS library, loaded once from Constants.TTS_DLL_PATH. */
    TtsService INSTANCE = Native.loadLibrary(Constants.TTS_DLL_PATH, TtsService.class);

    /** Logs in to the MSC runtime. */
    Integer MSPLogin(String usr, String pwd, String params);

    /** Begins an offline synthesis session; the error code is written to {@code errorCode}. */
    String QTTSSessionBegin(String params, IntByReference errorCode);

    /** Submits the text to synthesize ({@code textLen} is its length in bytes). */
    Integer QTTSTextPut(String sessionID, String textString, int textLen, String params);

    /** Fetches the next chunk of synthesized audio; length, synthesis status and error are written to the out-params. */
    Pointer QTTSAudioGet(String sessionID, IntByReference audioLen, IntByReference synthStatus, IntByReference errorCode);

    /** Ends the synthesis session. */
    Integer QTTSSessionEnd(String sessionID, String hints);

    /** Logs out of the MSC runtime. */
    Integer MSPLogout();
}
四、唤醒回调
package com.day.service.imp;
import com.day.AIMain;
import com.day.ability.IatMic;
import com.day.config.Constants;
import com.sun.jna.Callback;
import javax.sound.sampled.AudioFileFormat;
import javax.sound.sampled.AudioInputStream;
import javax.sound.sampled.AudioSystem;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
/**
 * JNA callback registered through QIVWRegisterNotify; the wake-up library invokes it with engine
 * messages for the session.
 *
 * <p>JNA maps the single public method below to the native notification function pointer, so no
 * other public methods may be added to this class.
 */
public class IvwCallback implements Callback {
    /** Message id of a successful wake-up (MSP_IVW_MSG_WAKEUP in the MSC headers). */
    private static final int MSP_IVW_MSG_WAKEUP = 1;
    /** Message id of a wake-up engine error (MSP_IVW_MSG_ERROR); param1 carries the error code. */
    private static final int MSP_IVW_MSG_ERROR = 2;

    /**
     * Handles one wake-up engine message.
     *
     * <p>On a wake-up it speaks an acknowledgement and then starts a dictation session. On an
     * engine error it raises {@code AIMain.ivwFlag} so the wake-up loop restarts its session.
     * Previously every message, errors included, was treated as a wake-up. Other messages are
     * ignored.
     *
     * @return always 0
     */
    public int cb_ivw_msg_proc(String sessionID, int msg, int param1, int param2, String info, String userData) throws Exception {
        if (msg == MSP_IVW_MSG_ERROR) {
            System.out.println("唤醒出错...:" + param1);
            AIMain.ivwFlag = true;
            return 0;
        }
        if (msg != MSP_IVW_MSG_WAKEUP) {
            return 0; // not a wake-up; nothing to do
        }
        System.out.println("机器人大飞:主人,您请说~");
        AIMain.startTts("主人,您请说~");
        // Acknowledge first, then start dictation (it runs on its own thread).
        IatMic.iatWork();
        return 0;
    }
}
五、听写代码(重点是和唤醒共用一个麦克风音频流)
package com.day.ability;
import com.day.AIMain;
import com.day.config.Constants;
import com.day.service.IvwService;
import com.google.gson.Gson;
import com.google.gson.JsonObject;
import okhttp3.*;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
import javax.sound.sampled.AudioInputStream;
import java.io.IOException;
import java.net.URL;
import java.nio.charset.Charset;
import java.text.SimpleDateFormat;
import java.util.*;
// Dictation (speech-to-text) that streams the shared microphone audio.
/**
 * Streaming dictation client for the iFlytek online "iat" WebSocket API.
 *
 * <p>It does not read the microphone itself. While a session is active, it forwards the latest
 * chunk published by the wake-up loop ({@code AIMain.audioDataByteArray}/{@code AIMain.len})
 * every 200 ms. When the final result arrives, it passes the text to
 * {@link BigModelNew#doSpark(String)}. It also raises {@code AIMain.ivwFlag} so the wake-up
 * session restarts; this happens on errors too.
 */
public class IatMic extends WebSocketListener {
    private static final String hostUrl = "https://iat-api.xfyun.cn/v2/iat"; // Chinese/English; the http(s) URL is rewritten to ws(s) below
    // private static final String hostUrl = "https://iat-niche-api.xfyun.cn/v2/iat"; // minor languages
    private static final String appid = ""; // from the console: My Apps
    private static final String apiSecret = ""; // from the console: My Apps -> streaming dictation
    private static final String apiKey = ""; // from the console: My Apps -> streaming dictation
    public static final int StatusFirstFrame = 0;
    public static final int StatusContinueFrame = 1;
    public static final int StatusLastFrame = 2;
    public static final Gson json = new Gson();
    /** Accumulates partial results (with dynamic correction) for the current session. */
    Decoder decoder = new Decoder();
    /** Frame status of the audio sender (0 first, 1 continue, 2 last); shared across threads. */
    static volatile int status = 0;
    /** True while the audio sender thread should keep streaming; cleared when the session ends. */
    public static volatile boolean IAT_FLAG = true;
    public static String fileName = "";

    public static void main(String[] args) throws Exception {
        iatWork();
    }

    /** Resets the shared session state and opens the WebSocket; audio is sent from onOpen. */
    static class MyThread extends Thread {
        @Override
        public void run() {
            IatMic.IAT_FLAG = true;
            status = 0;
            IatMic iatMic = new IatMic();
            String authUrl;
            try {
                authUrl = getAuthUrl(hostUrl, apiKey, apiSecret);
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
            OkHttpClient client = new OkHttpClient.Builder().build();
            // Swap the http:// / https:// scheme for ws:// / wss://.
            String url = authUrl.replace("http://", "ws://").replace("https://", "wss://");
            Request request = new Request.Builder().url(url).build();
            client.newWebSocket(request, iatMic);
        }
    }

    /** Starts a dictation session on a background thread so the caller (the wake-up callback) is not blocked. */
    public static void iatWork() throws Exception {
        MyThread myThread = new MyThread();
        myThread.start();
    }

    @Override
    public void onOpen(WebSocket webSocket, Response response) {
        System.out.println(Constants.YELLOW_BACKGROUND + "机器人正在听,您请说:" + Constants.RESET);
        super.onOpen(webSocket, response);
        new Thread(() -> streamAudio(webSocket)).start();
    }

    /**
     * Forwards the latest published microphone chunk every 200 ms. Stops when {@link #IAT_FLAG}
     * is cleared or the capture stream ends, then sends the closing (status 2) frame.
     */
    private void streamAudio(WebSocket webSocket) {
        try {
            Constants.IVW_ASR_TARGET_DATA_LINE.open(Constants.IVW_ASR_AUDIO_FORMAT);
            Constants.IVW_ASR_TARGET_DATA_LINE.start();
            while (true) {
                // Read array and length once per iteration so both refer to the same chunk.
                byte[] chunk = AIMain.audioDataByteArray;
                int chunkLen = AIMain.len;
                if (chunkLen == -1) {
                    status = StatusLastFrame; // capture stream ended
                }
                if (status == StatusLastFrame) {
                    break; // send the last frame right away instead of idling
                }
                if (status == StatusFirstFrame) {
                    webSocket.send(buildFirstFrame(chunk, chunkLen));
                    status = StatusContinueFrame;
                } else {
                    webSocket.send(buildAudioFrame(StatusContinueFrame, encode(chunk, chunkLen)));
                }
                try {
                    Thread.sleep(200);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    break;
                }
                if (!IAT_FLAG) {
                    break; // session finished
                }
            }
            status = StatusLastFrame;
            webSocket.send(buildAudioFrame(StatusLastFrame, ""));
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Builds the first frame, which additionally carries the common (app id) and business parameters. */
    private static String buildFirstFrame(byte[] chunk, int chunkLen) {
        JsonObject common = new JsonObject();
        common.addProperty("app_id", appid);
        JsonObject business = new JsonObject();
        business.addProperty("language", "zh_cn"); // alternatives: en_us, ja_jp, ko_kr, ru-ru (some need enabling in the console)
        business.addProperty("domain", "iat");
        business.addProperty("accent", "mandarin"); // dialects such as cantonese must be enabled in the console
        business.addProperty("vinfo", 1);
        business.addProperty("dwa", "wpgs"); // dynamic correction; ignored unless enabled in the console
        business.addProperty("vad_eos", 3000);
        JsonObject data = new JsonObject();
        data.addProperty("status", StatusFirstFrame);
        data.addProperty("format", "audio/L16;rate=16000");
        data.addProperty("encoding", "raw");
        data.addProperty("audio", encode(chunk, chunkLen));
        JsonObject frame = new JsonObject();
        frame.add("common", common);
        frame.add("business", business);
        frame.add("data", data);
        return frame.toString();
    }

    /** Builds a continue/last frame, which carries only the data section. */
    private static String buildAudioFrame(int frameStatus, String audioBase64) {
        JsonObject data = new JsonObject();
        data.addProperty("status", frameStatus);
        data.addProperty("format", "audio/L16;rate=16000");
        data.addProperty("encoding", "raw");
        data.addProperty("audio", audioBase64);
        JsonObject frame = new JsonObject();
        frame.add("data", data);
        return frame.toString();
    }

    /** Base64-encodes the first {@code chunkLen} bytes of {@code chunk}; empty if there is no audio. */
    private static String encode(byte[] chunk, int chunkLen) {
        if (chunk == null || chunkLen <= 0) {
            return "";
        }
        return Base64.getEncoder().encodeToString(Arrays.copyOf(chunk, chunkLen));
    }

    @Override
    public void onMessage(WebSocket webSocket, String text) {
        super.onMessage(webSocket, text);
        ResponseData resp = json.fromJson(text, ResponseData.class);
        if (resp == null) {
            return;
        }
        if (resp.getCode() != 0) {
            AIMain.ivwFlag = true; // resume wake-up on errors too
            System.out.println("code=>" + resp.getCode() + " error=>" + resp.getMessage() + " sid=" + resp.getSid());
            System.out.println("错误码查询链接:https://www.xfyun.cn/document/error-code");
            // Stop the sender thread and release the socket; previously the sender spun forever.
            IatMic.IAT_FLAG = false;
            decoder.discard();
            webSocket.close(1000, "");
            return;
        }
        Data data = resp.getData();
        if (data == null) {
            return;
        }
        if (data.getResult() != null) {
            Text te = data.getResult().getText();
            try {
                decoder.decode(te);
                System.out.println(Constants.YELLOW_BACKGROUND + "用户说话识别中:" + decoder.toString() + Constants.RESET);
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
        if (data.getStatus() == 2) {
            // Status 2: all results have been returned, so finish the session.
            String finalText = decoder.toString();
            System.out.println(Constants.YELLOW_BACKGROUND + "用户说话识别最终结果:" + finalText + Constants.RESET);
            AIMain.ivwFlag = true; // resume wake-up
            // Clean up before handing off, so a failure in doSpark cannot leave the sender running.
            decoder.discard();
            webSocket.close(1000, "");
            IatMic.IAT_FLAG = false;
            try {
                BigModelNew.doSpark(finalText); // ask the large model
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
    }

    @Override
    public void onFailure(WebSocket webSocket, Throwable t, Response response) {
        super.onFailure(webSocket, t, response);
        // The session is dead: stop the audio sender and let the wake-up loop restart.
        IatMic.IAT_FLAG = false;
        AIMain.ivwFlag = true;
        if (t != null) {
            t.printStackTrace();
        }
        try {
            if (null != response) {
                int code = response.code();
                System.out.println("onFailure code:" + code);
                String body = response.body() == null ? "null" : response.body().string();
                System.out.println("onFailure body:" + body);
                if (101 != code) {
                    System.out.println("connection failed");
                    System.exit(0);
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Builds the signed request URL (HMAC-SHA256 over host, date and request line) required by
     * the iFlytek WebSocket API.
     */
    public static String getAuthUrl(String hostUrl, String apiKey, String apiSecret) throws Exception {
        URL url = new URL(hostUrl);
        SimpleDateFormat format = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss z", Locale.US);
        format.setTimeZone(TimeZone.getTimeZone("GMT"));
        String date = format.format(new Date());
        String signatureOrigin = "host: " + url.getHost() + "\n"
                + "date: " + date + "\n"
                + "GET " + url.getPath() + " HTTP/1.1";
        Charset charset = Charset.forName("UTF-8");
        Mac mac = Mac.getInstance("hmacsha256");
        SecretKeySpec spec = new SecretKeySpec(apiSecret.getBytes(charset), "hmacsha256");
        mac.init(spec);
        byte[] hexDigits = mac.doFinal(signatureOrigin.getBytes(charset));
        String sha = Base64.getEncoder().encodeToString(hexDigits);
        String authorization = String.format("api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey, "hmac-sha256", "host date request-line", sha);
        HttpUrl httpUrl = Objects.requireNonNull(HttpUrl.parse("https://" + url.getHost() + url.getPath())).newBuilder()
                .addQueryParameter("authorization", Base64.getEncoder().encodeToString(authorization.getBytes(charset)))
                .addQueryParameter("date", date)
                .addQueryParameter("host", url.getHost())
                .build();
        return httpUrl.toString();
    }

    /** Top-level response envelope (populated by Gson). */
    public static class ResponseData {
        private int code;
        private String message;
        private String sid;
        private Data data;

        public int getCode() {
            return code;
        }

        public String getMessage() {
            return this.message;
        }

        public String getSid() {
            return sid;
        }

        public Data getData() {
            return data;
        }
    }

    /** Response payload: session status (2 = final) and the recognition result. */
    public static class Data {
        private int status;
        private Result result;

        public int getStatus() {
            return status;
        }

        public Result getResult() {
            return result;
        }
    }

    /** One recognition result as sent by the server (populated by Gson). */
    public static class Result {
        int bg;
        int ed;
        String pgs;
        int[] rg;
        int sn;
        Ws[] ws;
        boolean ls;
        JsonObject vad;

        /** Flattens this result into a {@link Text}, taking the first candidate word of each segment. */
        public Text getText() {
            StringBuilder sb = new StringBuilder();
            if (this.ws != null) {
                for (Ws w : this.ws) {
                    if (w != null && w.cw != null && w.cw.length > 0) {
                        sb.append(w.cw[0].w);
                    }
                }
            }
            Text text = new Text();
            text.sn = this.sn;
            text.text = sb.toString();
            text.rg = this.rg;
            text.pgs = this.pgs;
            text.bg = this.bg;
            text.ed = this.ed;
            text.ls = this.ls;
            text.vad = this.vad;
            return text;
        }
    }

    /** A word segment with its candidate words. */
    public static class Ws {
        Cw[] cw;
        int bg;
        int ed;
    }

    /** A candidate word and its score. */
    public static class Cw {
        int sc;
        String w;
    }

    /** Decoded text of one result, tracked by sequence number {@code sn}. */
    public static class Text {
        int sn;
        int bg;
        int ed;
        String text;
        String pgs;
        int[] rg;
        boolean deleted;
        boolean ls;
        JsonObject vad;

        @Override
        public String toString() {
            String vadWs = vad == null ? "null" : String.valueOf(vad.getAsJsonArray("ws"));
            return "Text{" + "bg=" + bg + ", ed=" + ed + ", ls=" + ls + ", sn=" + sn + ", text='" + text + '\'' + ", pgs=" + pgs + ", rg=" + Arrays.toString(rg) + ", deleted=" + deleted + ", vad=" + vadWs + '}';
        }
    }

    /**
     * Assembles partial results into the full utterance. It handles dynamic correction: a result
     * with {@code pgs == "rpl"} marks the results in range {@code rg} as replaced.
     */
    public static class Decoder {
        private Text[] texts;
        private int defc = 10;

        public Decoder() {
            this.texts = new Text[this.defc];
        }

        public synchronized void decode(Text text) {
            // Grow until sn fits; a single doubling is not enough for large sequence numbers.
            while (text.sn >= this.defc) {
                this.resize();
            }
            if ("rpl".equals(text.pgs) && text.rg != null && text.rg.length >= 2) {
                int last = Math.min(text.rg[1], this.texts.length - 1);
                for (int i = Math.max(text.rg[0], 0); i <= last; i++) {
                    if (this.texts[i] != null) { // a replaced result may never have arrived
                        this.texts[i].deleted = true;
                    }
                }
            }
            this.texts[text.sn] = text;
        }

        @Override
        public synchronized String toString() {
            StringBuilder sb = new StringBuilder();
            for (Text t : this.texts) {
                if (t != null && !t.deleted) {
                    sb.append(t.text);
                }
            }
            return sb.toString();
        }

        /** Doubles the capacity, keeping existing entries. */
        public void resize() {
            this.defc <<= 1;
            this.texts = Arrays.copyOf(this.texts, this.defc);
        }

        /** Clears all accumulated results. */
        public synchronized void discard() {
            Arrays.fill(this.texts, null);
        }
    }
}
六、大模型调用代码
package com.day.ability;
import com.day.AIMain;
import com.day.util.MyUtil;
import com.google.gson.Gson;
import okhttp3.HttpUrl;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.text.SimpleDateFormat;
import java.util.*;
// Client for the iFlytek Spark large language model over HTTP.
/**
 * Sends a user utterance to the Spark HTTP completion endpoint and speaks the answer via
 * {@link AIMain#startTts(String)}.
 */
public class BigModelNew {
    public static final String hostUrl = "https://spark-api.xf-yun.com/v3/completions";
    private static final String appid = "";
    private static final String apiSecret = "";
    private static final String apiKey = "";
    private static final Gson gson = new Gson();
    /** Spoken when the response contains no answer text. */
    private static final String FALLBACK_ANSWER = "您好,我是讯飞星火认知大模型";

    public static void main(String[] args) throws Exception {
        doSpark("我想吃鸡。");
    }

    /** Asks the model on a background thread so the caller (the dictation callback) is not blocked. */
    public static void doSpark(String content) throws Exception {
        MyThread myThread = new MyThread(content);
        myThread.start();
    }

    /** Worker thread: authenticate, POST the question, parse the streamed answer and speak it. */
    static class MyThread extends Thread {
        String content;

        public MyThread(String content) {
            this.content = content;
        }

        @Override
        public void run() {
            String authUrl;
            try {
                authUrl = getAuthUrl(hostUrl, apiKey, apiSecret);
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
            String res = MyUtil.doPostJson(authUrl, null, buildRequestJson(content));
            String finalRes = parseAnswer(res);
            System.out.println(finalRes);
            // Line breaks are removed before the text is handed to TTS.
            String temp = finalRes.replaceAll("\r\n", "").replaceAll("\n", "");
            System.out.println("*****************************************************************************************************");
            AIMain.startTts(temp);
        }
    }

    /**
     * Builds the streaming chat request body.
     *
     * <p>The user text is JSON-escaped with Gson. Plain concatenation produced invalid JSON
     * whenever the utterance contained a quote, a backslash or a control character.
     */
    static String buildRequestJson(String content) {
        String uid = UUID.randomUUID().toString().substring(0, 10);
        return "{\n"
                + "  \"app_id\": " + gson.toJson(appid) + ",\n"
                + "  \"uid\": \"" + uid + "\",\n"
                + "  \"domain\": \"generalv2\",\n"
                + "  \"temperature\": 0.5,\n"
                + "  \"max_tokens\": 4096,\n"
                + "  \"auditing\": \"default\",\n"
                + "  \"stream\": true,\n"
                + "  \"messages\": [\n"
                + "    {\n"
                + "      \"role\": \"user\",\n"
                + "      \"content\": " + gson.toJson(content) + "\n"
                + "    }\n"
                + "  ]\n"
                + "}";
    }

    /**
     * Concatenates the {@code choices[].content} of every {@code data:} line of a streamed
     * response. Returns {@link #FALLBACK_ANSWER} only if no content was found. Previously, any
     * chunk without choices overwrote the answer accumulated so far.
     */
    static String parseAnswer(String res) {
        StringBuilder answer = new StringBuilder();
        if (res != null) {
            for (String line : res.split("\n")) {
                int idx = line.indexOf("data:");
                if (idx < 0) {
                    continue;
                }
                // Strip only the prefix; replace() also removed "data:" inside the answer text.
                String jsonStr = line.substring(idx + "data:".length()).trim();
                if (jsonStr.isEmpty() || "[DONE]".equals(jsonStr)) {
                    continue;
                }
                try {
                    BigJsonParse chunk = gson.fromJson(jsonStr, BigJsonParse.class);
                    List<Choices> choicesList = chunk == null ? null : chunk.choices;
                    if (choicesList == null) {
                        continue;
                    }
                    for (Choices choice : choicesList) {
                        if (choice != null && choice.content != null) {
                            answer.append(choice.content);
                        }
                    }
                } catch (RuntimeException e) { // Gson parse errors are unchecked; skip the bad line
                    e.printStackTrace();
                }
            }
        }
        return answer.length() > 0 ? answer.toString() : FALLBACK_ANSWER;
    }

    /**
     * Builds the signed request URL (HMAC-SHA256 over host, date and the POST request line)
     * required by the Spark HTTP API.
     */
    public static String getAuthUrl(String hostUrl, String apiKey, String apiSecret) throws Exception {
        URL url = new URL(hostUrl);
        SimpleDateFormat format = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss z", Locale.US);
        format.setTimeZone(TimeZone.getTimeZone("GMT"));
        String date = format.format(new Date());
        String preStr = "host: " + url.getHost() + "\n" + "date: " + date + "\n" + "POST " + url.getPath() + " HTTP/1.1";
        Mac mac = Mac.getInstance("hmacsha256");
        SecretKeySpec spec = new SecretKeySpec(apiSecret.getBytes(StandardCharsets.UTF_8), "hmacsha256");
        mac.init(spec);
        byte[] hexDigits = mac.doFinal(preStr.getBytes(StandardCharsets.UTF_8));
        String sha = Base64.getEncoder().encodeToString(hexDigits);
        String authorization = String.format("api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"", apiKey, "hmac-sha256", "host date request-line", sha);
        HttpUrl httpUrl = Objects.requireNonNull(HttpUrl.parse("https://" + url.getHost() + url.getPath())).newBuilder()
                .addQueryParameter("authorization", Base64.getEncoder().encodeToString(authorization.getBytes(StandardCharsets.UTF_8)))
                .addQueryParameter("date", date)
                .addQueryParameter("host", url.getHost())
                .build();
        return httpUrl.toString();
    }
}
// Gson mapping for one streamed Spark response chunk.
/** One {@code data:} chunk of the streamed answer. */
class BigJsonParse {
    // Typed element list so Gson builds Choices objects; a raw List yields generic maps and
    // makes iterating it as Choices fail.
    List<Choices> choices;
}

/** A single answer fragment. */
class Choices {
    String content;
}
七、HTTP POST请求代码
package com.day.util;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpUriRequest;
import org.apache.http.client.methods.RequestBuilder;
import org.apache.http.entity.ContentType;
import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.util.EntityUtils;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.Map;
/** Small HTTP helper built on Apache HttpClient. */
public class MyUtil {
    /**
     * Sends a POST request with a JSON body and returns the response body decoded as UTF-8.
     *
     * @param url       target URL; converted to its ASCII form before use
     * @param urlParams optional extra request parameters (may be {@code null})
     * @param json      request body, sent as {@code application/json}
     * @return the response body, or {@code ""} if the request failed (the error is printed)
     */
    public static String doPostJson(String url, Map<String, String> urlParams, String json) {
        String resultString = "";
        // try-with-resources closes the response and the client even when execution fails.
        try (CloseableHttpClient closeableHttpClient = HttpClients.createDefault()) {
            String asciiUrl = URI.create(url).toASCIIString();
            RequestBuilder builder = RequestBuilder.post(asciiUrl);
            builder.setCharset(StandardCharsets.UTF_8);
            if (urlParams != null) {
                // Typed entries: the raw Map.Entry could not be passed to addParameter(String, String).
                for (Map.Entry<String, String> entry : urlParams.entrySet()) {
                    builder.addParameter(entry.getKey(), entry.getValue());
                }
            }
            builder.setEntity(new StringEntity(json, ContentType.APPLICATION_JSON));
            HttpUriRequest request = builder.build();
            try (CloseableHttpResponse closeableHttpResponse = closeableHttpClient.execute(request)) {
                resultString = EntityUtils.toString(closeableHttpResponse.getEntity(), StandardCharsets.UTF_8);
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
        return resultString;
    }
}
八、整体项目结构目录