实现在Unity内部的大模型访问,我也是第一次接触Unity中通过大模型url访问。此博客面向新手,旨在给大家简单理解大模型POST和GET过程,还有实现简单的大模型访问。
参考博客:什么是chatGPT?Unity结合OpenAI官方api实现类似chatGPT的AI聊天机器人
附带源码地址:OpenAIChatRobotMaster: 使用unity实现的基于OpenAI官方api的AI聊天机器人示例
参考的博客主要用于访问CHATGPT,但是我目前需求是访问自己的大模型URL,其中碰到的问题以及代码的一些详细解读。
目录
一、效果展示
二、源码POST解读
三、源码修改
四、代码评价
五、整个代码
UI改了一下,具体效果如上,这是使用了我们实验室自己部署的vicuna-13b大模型
首先原博客的模型是chatgpt的:text-davinci-003模型,模型的请求体和响应体如下:
所以原博客的unity代码中有这样的封装:
[System.Serializable]public class PostData{
public string model;
public string prompt;
public int max_tokens;
public float temperature;
public int top_p;
public float frequency_penalty;
public float presence_penalty;
public string stop;
}
[System.Serializable]public class TextCallback{
public string id;
public string created;
public string model;
public List choices;
[System.Serializable]public class TextSample{
public string text;
public string index;
public string finish_reason;
}
}
代码逻辑我做了一个图,大家可以看看,可以方便理解源代码:
逻辑概述:输入框的文本信息_msg一方面渲染到了聊天框m_PostChatPrefab,另一方面被封装到了_postData类里的prompt中。将json信息传递到LLM中,最后返回的_msg转json格式得到_textback格式。我们需要得到的是其中的choices[0]。将得到的choices[0]渲染到聊天框m_RobotChatPrefab中。
具体修改部分主要是请求体和响应体的格式。源码中post过程没有错误答应部分,我自己中途打印了一些中间变量,方便查错:
本次以vicuna-13b大模型为例,官方文档没有请求体和响应体的格式,所以通过postman来查看的格式,如下:
参照上面的内容修改代码如下:
[System.Serializable]public class PostData{
public string model;
public List messages;
[System.Serializable]
public class Messages
{
public string role;
public string content;
}
public int web;
public int account_id;
public int conversation_id;
public float temperature;
public int top_p;
public int n;
public int max_tokens;
public string stop;
public bool stream;
public float frequency_penalty;
public float presence_penalty;
public string user;
}
[System.Serializable]public class TextCallback{
public string id;
public string created;
public string model;
public List choices;
[System.Serializable]public class TextSample
{
public string index;
public Messages message;
[System.Serializable]
public class Messages
{
public string role;
public string content;
}
public string finish_reason;
}
}
其中的_postData传参代码修改如下:
PostData _postData = new PostData{
model = m_PostDataSetting.model,
messages = new List
{
new PostData.Messages
{
role = "user",
content = _postWord
}
},
web = m_PostDataSetting.web,
account_id = m_PostDataSetting.account_id,
conversation_id = m_PostDataSetting.conversation_id,
temperature = m_PostDataSetting.temperature,
top_p = m_PostDataSetting.top_p,
n = m_PostDataSetting.n,
max_tokens = m_PostDataSetting.max_tokens,
stop = "string",
stream= false,
frequency_penalty = m_PostDataSetting.frequency_penalty,
presence_penalty = m_PostDataSetting.presence_penalty,
user = "string"
};
心得:
其中报错内容有1.打印出的_jsonText没有message
2.TextCallback.TextSample.Messages定义不正确
原因:
string _jsonText = JsonUtility.ToJson (_postData);中:
JsonUtility.ToJson()
方法具有一些限制,它只能序列化Unity支持的类型,并且不能序列化嵌套的自定义类型(如Messages
类)。所以在新加的public class Messages前面添加[System.Serializable],就可以序列化了。
代码方面通俗易懂,但是因为后续工作需求,代码还有许多功能需要增加:
1.此代码将每一次的对话都直接渲染到聊天框中,没有在内部进行存储,导致的结果就是不能进行多轮对话,后续我将朝这个方向进行改进。
2.此对话响应方式是响应全部结束后才渲染出来,不能一个字一个字的流式响应。
针对这些需求,后续会对代码进行修改。
private string m_ApiUrl = "http://*********/completion";
//配置参数,用于存储聊天界面的一些配置信息。
[SerializeField] private PostData m_PostDataSetting;
//输入的信息,用于获取用户输入的聊天内容。
[SerializeField] private InputField m_InputWord;
//聊天文本放置的层,用于存储聊天气泡的位置信息。
[SerializeField] private RectTransform m_rootTrans;
//发送聊天气泡,用于存储用户发送的聊天气泡的预制体。
[SerializeField] private ChatPrefab m_PostChatPrefab;
//回复的聊天气泡,用于存储机器人回复的聊天气泡的预制体。
[SerializeField] private ChatPrefab m_RobotChatPrefab;
//滚动条,用于控制聊天界面的滚动。
[SerializeField] private ScrollRect m_ScroTectObject;
///
/// 发送信息UI
///
public void SendData()
{
if (m_InputWord.text.Equals(""))
return;
//将输入框中的文本作为消息进行处理,
string _msg = m_InputWord.text;
// 以m_PostChatPrefab预制体为模板,生成聊天记录
ChatPrefab _chat = Instantiate(m_PostChatPrefab, m_rootTrans.transform);
_chat.SetText(_msg);
//重新计算容器尺寸
LayoutRebuilder.ForceRebuildLayoutImmediate(m_rootTrans);
//使用协程(TurnToLastLine())确保聊天框始终显示最新的聊天记录。
StartCoroutine(TurnToLastLine());
//获取发送的数据并将其传递给回调函数(CallBack)。
StartCoroutine(GetPostData(_msg, CallBack));
//清空输入框文本
m_InputWord.text = "";
}
///
/// AI回复的信息UI
///
///
private void CallBack(string _callback)
{
//去除字符串两侧的空格
_callback = _callback.Trim();
//将该字符串传递给 ChatPrefab 类的实例变量 _chat。
ChatPrefab _chat = Instantiate(m_RobotChatPrefab, m_rootTrans.transform);
_chat.SetText(_callback);
//重新计算容器尺寸
LayoutRebuilder.ForceRebuildLayoutImmediate(m_rootTrans);
//将页面滚动到最后一行,
StartCoroutine(TurnToLastLine());
}
///
///UI协程函数, 将文本框滚动到最后一行消息的位置。
///
///
private IEnumerator TurnToLastLine()
{
yield return new WaitForEndOfFrame();
//滚动到最近的消息
m_ScroTectObject.verticalNormalizedPosition = 0;
}
///
/// 设置AI模型类型model
///
///
public void SetAIModel(Toggle _modelType)
{
if (_modelType.isOn)
{
m_PostDataSetting.model = _modelType.name;
}
}
//---------------------------------------------------------------------------------------------------------------
///
/// 用于存储向AI模型发送的参数数据。
///
[System.Serializable]
public class PostData
{
public string model;
public List messages;
[System.Serializable]
public class Messages
{
public string role;
public string content;
}
public int web;
public int account_id;
public int conversation_id;
public float temperature;
public int top_p;
public int n;
public int max_tokens;
public string stop;
public bool stream;
public float frequency_penalty;
public float presence_penalty;
//public string user;
}
///
/// 向AI模型发送数据
///
///
///
///
private IEnumerator GetPostData(string _postWord, System.Action _callback)
{
//UnityWebRequest发送POST请求,接口:m_ApiUrl
var request = new UnityWebRequest(m_ApiUrl, "POST");
PostData _postData = new PostData
{
model = m_PostDataSetting.model,
messages = new List{
new PostData.Messages
{
role = "user",
content = _postWord
}
},
web = m_PostDataSetting.web,
account_id = m_PostDataSetting.account_id,
conversation_id = m_PostDataSetting.conversation_id,
temperature = m_PostDataSetting.temperature,
top_p = m_PostDataSetting.top_p,
n = m_PostDataSetting.n,
max_tokens = m_PostDataSetting.max_tokens,
stop = m_PostDataSetting.stop,
stream = false,
frequency_penalty = m_PostDataSetting.frequency_penalty,
presence_penalty = m_PostDataSetting.presence_penalty,
//user = m_PostDataSetting.user
};
//将 _postData 转换成 JSON 格式的字符串
string _jsonText = JsonUtility.ToJson(_postData);
byte[] data = System.Text.Encoding.UTF8.GetBytes(_jsonText);
//请求数据data上传
request.uploadHandler = (UploadHandler)new UploadHandlerRaw(data);
//设置请求的下载处理器DownloadHandlerBuffer,,,返回的数据存储在缓存区
request.downloadHandler = (DownloadHandler)new DownloadHandlerBuffer();
//设置请求头。。告诉服务器上传的数据为 JSON 格式。
request.SetRequestHeader("Content-Type", "application/json");
//request.SetRequestHeader("Authorization",string.Format("Bearer {0}",m_OpenAI_Key));
//异步发送请求并等待响应
yield return request.SendWebRequest();
Debug.Log("Response Code: " + request.responseCode);
if (request.responseCode == 200)
{
string _msg = request.downloadHandler.text;
Debug.Log(" _msg: " + _msg);
//将_msg转化为TextCallback数据结构
if (!string.IsNullOrEmpty(_msg))
{
TextCallback _textback = JsonUtility.FromJson(_msg);
Debug.Log("_textback: " + _textback);
Debug.Log("_textback.choices[0]: " + _textback.choices[0]);
Debug.Log("_textback.choices[0].message: " + _textback.choices[0].message);
Debug.Log("_textback.choices[0].message.content: " + _textback.choices[0].message.content);
if (_textback != null && _textback.choices != null && _textback.choices.Count > 0)
{
_callback(_textback.choices[0].message.content);
}
else { Debug.LogError("Request Error: Invalid response data."); }
}
else
{ Debug.LogError("Request Error: Empty response data."); }
}
else
{
Debug.LogError("Request Error: " + request.responseCode);
}
}
///
/// 用于退出应用程序,
///
public void Quit()
{ Application.Quit(); }
void Update()
{
if (Input.GetKeyDown(KeyCode.Escape))
{ Application.Quit(); }
if (Input.GetKeyDown(KeyCode.Return))
{ SendData(); }
}
[System.Serializable]
public class TextCallback
{
public string id;
public string created;
public string model;
public List choices;
[System.Serializable]
public class TextSample
{
public string index;
public Messages message;
[System.Serializable]
public class Messages
{
public string role;
public string content;
}
public string finish_reason;
}
}