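This article records the problems I ran into while following an online chatbot tutorial. For the full walkthrough, please see the original chatbot tutorial page.
I found a chat corpus online that was built from film and TV subtitles. The candidate answers to a question are ranked by relevance and by conversation history, and the best one is selected. This article covers that search-and-ranking process.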
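To make the "rank the candidates and pick the best" step concrete, here is a minimal sketch in plain Java. It is not the tutorial's code: the `AnswerRanker` and `Candidate` names, the example scores, and the rule of skipping answers that already appear in the history are all assumptions made for illustration.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

// Hypothetical sketch: choose the most relevant answer, skipping answers that
// were already given earlier in the conversation (an assumed history rule).
class AnswerRanker {
    static final class Candidate {
        final String answer;
        final double relevance;

        Candidate(String answer, double relevance) {
            this.answer = answer;
            this.relevance = relevance;
        }
    }

    // Returns the highest-relevance answer not found in history, or null if none is left.
    static String best(List<Candidate> candidates, List<String> history) {
        List<Candidate> sorted = new ArrayList<Candidate>(candidates);
        sorted.sort(Comparator.comparingDouble((Candidate c) -> c.relevance).reversed());
        for (Candidate c : sorted) {
            if (!history.contains(c.answer)) {
                return c.answer;
            }
        }
        return null;
    }

    public static void main(String[] args) {
        List<Candidate> candidates = Arrays.asList(
                new Candidate("I am fine", 2.5),
                new Candidate("Hello", 3.1),
                new Candidate("Who are you", 1.2));
        // "Hello" scores highest but was already used, so the next best is chosen.
        System.out.println(best(candidates, Arrays.asList("Hello")));
    }
}
```

In the code discussed in this article the relevance score comes from Lucene; the history rule here is only a placeholder for whatever weighting a real system would use.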
There are four Java files in total:
Action.java
HttpServerInboundHandler.java
NettyHttpServletResponse.java
Searcher.java
Run Searcher.java:
public static void main(String[] args) throws InterruptedException {
    // Boss group with a thread pool of size 1: accepts incoming connections.
    EventLoopGroup bossGroup = new NioEventLoopGroup(1);
    // Worker group with the default thread pool size: handles reads and writes on accepted connections.
    EventLoopGroup workerGroup = new NioEventLoopGroup();
    // ServerBootstrap is the Netty helper class used to start a server.
    ServerBootstrap b = new ServerBootstrap();
    b.group(bossGroup, workerGroup)
            .channel(NioServerSocketChannel.class) // use the Java NIO-based NioServerSocketChannel
            .option(ChannelOption.SO_BACKLOG, 128)
            .handler(new LoggingHandler(LogLevel.TRACE))
            .childHandler(new ChannelInitializer<SocketChannel>() { // sets up the I/O thread's pipeline for each accepted channel
                @Override
                public void initChannel(SocketChannel ch) throws Exception {
                    ChannelPipeline p = ch.pipeline();
                    p.addLast("http-decoder", new HttpRequestDecoder());
                    p.addLast(new HttpObjectAggregator(65535));
                    p.addLast("http-encoder", new HttpResponseEncoder());
                    p.addLast(new HttpServerInboundHandler());
                }
            });
    // Bind the local port and wait synchronously; 9999 or 80 can be tried instead of 8080.
    ChannelFuture f = b.bind(8080).sync();
    // Block until the server socket is closed.
    f.channel().closeFuture().sync();
}
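Once the server is bound to port 8080, it can be queried over plain HTTP. The following sketch only builds the request URL; it assumes the server runs on localhost and uses the `q` and `clientIp` parameters that `Action.doServlet` reads later in this article. The `ChatUrlBuilder` class name is made up for the example.

```java
import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;

// Sketch of building a request URL for the server started above.
// Host, port and parameter values are illustrative assumptions.
class ChatUrlBuilder {
    static String buildUrl(String host, int port, String question, String clientIp) {
        return "http://" + host + ":" + port + "/?q=" + encode(question)
                + "&clientIp=" + encode(clientIp);
    }

    // Percent-encodes a value as UTF-8 so that non-ASCII questions survive the query string.
    private static String encode(String value) {
        try {
            return URLEncoder.encode(value, "UTF-8");
        } catch (UnsupportedEncodingException e) {
            throw new IllegalStateException("UTF-8 is always supported", e);
        }
    }

    public static void main(String[] args) {
        // "\u4f60\u597d" is the Chinese greeting "ni hao", written with escapes
        // so the source file does not depend on its own encoding.
        System.out.println(buildUrl("localhost", 8080, "\u4f60\u597d", "127.0.0.1"));
    }
}
```

Opening the printed URL in a browser, or requesting `/?q=monitor`, is a quick way to check that the server is alive.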
The pipeline calls HttpServerInboundHandler.java:
public class HttpServerInboundHandler extends SimpleChannelInboundHandler<FullHttpRequest> {
    @Override
    protected void messageReceived(ChannelHandlerContext ctx, FullHttpRequest msg) throws Exception {
        // Build an empty 200 OK response and let Action fill in the body.
        NettyHttpServletResponse res = new NettyHttpServletResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK);
        Action.doServlet(msg, res);
        // Write the response, then close the connection once the write completes.
        ChannelFuture future = ctx.channel().writeAndFlush(res);
        future.addListener(ChannelFutureListener.CLOSE);
    }
}
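Note: this callback has a different name depending on the Netty version. It is messageReceived in the Netty 5.0 alpha releases and channelRead0 in Netty 4.x (or channelRead when extending ChannelInboundHandlerAdapter directly).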
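The relationship between the generic `channelRead` entry point and the typed callback overridden above can be modelled in plain Java. This is a simplified stand-in, not Netty code: `TypedHandler` and `TypedHandlerDemo` are hypothetical names, and a real handler also deals with the channel context and with releasing the message.

```java
// Demo of the typed-dispatch idea behind SimpleChannelInboundHandler.
// All names here are illustrative; none of this is Netty's API.
class TypedHandlerDemo {
    static String lastSeen = null;

    static TypedHandler<String> stringHandler() {
        return new TypedHandler<String>(String.class) {
            @Override
            protected void messageReceived(String msg) {
                lastSeen = msg;
            }
        };
    }

    public static void main(String[] args) {
        TypedHandler<String> handler = stringHandler();
        System.out.println(handler.channelRead("hello")); // matching type: typed callback runs
        System.out.println(handler.channelRead(42));      // other type: callback is skipped
    }
}

abstract class TypedHandler<I> {
    private final Class<I> type;

    TypedHandler(Class<I> type) {
        this.type = type;
    }

    // Generic entry point: receives any message and forwards only matching ones.
    // Returns true when the typed callback handled the message.
    final boolean channelRead(Object msg) {
        if (type.isInstance(msg)) {
            messageReceived(type.cast(msg));
            return true;
        }
        return false; // a real pipeline would pass the message on to the next handler
    }

    // Typed callback that subclasses implement.
    protected abstract void messageReceived(I msg);
}
```

This is why the handler in this article only ever sees `FullHttpRequest` objects: the type check happens before the typed callback is invoked.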
Action.java
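This code handles one request: it reads the q parameter, searches the Lucene index, and returns the hits as JSON. Related background is noted in the comments.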
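As background for the listing that follows: the request parameters arrive in the URI's query string, which the code reads through Netty's `QueryStringDecoder` as a `Map<String, List<String>>`. The sketch below imitates that shape in plain Java so the idea can be tried without Netty. It is a simplified illustration, not Netty's implementation, and the `QueryParams` class name is made up.

```java
import java.io.UnsupportedEncodingException;
import java.net.URLDecoder;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Simplified illustration of turning a URI's query string into a
// Map<String, List<String>>. Not Netty's actual implementation.
class QueryParams {
    static Map<String, List<String>> parse(String uri) {
        Map<String, List<String>> params = new LinkedHashMap<String, List<String>>();
        int mark = uri.indexOf('?');
        if (mark < 0 || mark == uri.length() - 1) {
            return params; // no query string present
        }
        for (String pair : uri.substring(mark + 1).split("&")) {
            if (pair.isEmpty()) {
                continue;
            }
            int eq = pair.indexOf('=');
            String key = decode(eq < 0 ? pair : pair.substring(0, eq));
            String value = eq < 0 ? "" : decode(pair.substring(eq + 1));
            List<String> values = params.get(key);
            if (values == null) {
                values = new ArrayList<String>();
                params.put(key, values);
            }
            values.add(value); // a repeated key accumulates several values
        }
        return params;
    }

    private static String decode(String s) {
        try {
            return URLDecoder.decode(s, "UTF-8");
        } catch (UnsupportedEncodingException e) {
            throw new IllegalStateException("UTF-8 is always supported", e);
        }
    }

    public static void main(String[] args) {
        System.out.println(parse("/?q=hello+world&clientIp=127.0.0.1"));
        System.out.println(parse("/?q=a&q=b").get("q").size());
    }
}
```

Because a key may appear more than once, each value is a list. That is why the listing checks `l.size() == 1` before using the question.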
public class Action {
    private static final Logger log = Logger.getLogger(Action.class); // named logger for the search code
    private static final Logger logChat = Logger.getLogger("chat"); // logger for the chat log
    private static final int MAX_RESULT = 10;
    private static final int MAX_TOTAL_HITS = 1000000;
    private static IndexSearcher indexSearcher = null;

    /* Open the index directory. */
    static {
        IndexReader reader = null;
        try {
            reader = DirectoryReader.open(FSDirectory.open(new File("./index")));
        } catch (IOException e) {
            e.printStackTrace();
            System.exit(-1);
        }
        indexSearcher = new IndexSearcher(reader);
    }

    public static void doServlet(FullHttpRequest req, NettyHttpServletResponse res) throws IOException, ParseException {
        ByteBuf buf = null;
        // Decode the URI's query string into string key/value pairs.
        QueryStringDecoder qsd = new QueryStringDecoder(req.uri());
        /* Map: each key is a String, each value is a List of Strings. */
        Map<String, List<String>> mapParameters = qsd.parameters();
        /* List: the values associated with the key "q". */
        List<String> l = mapParameters.get("q");
        if (null != l && l.size() == 1) {
            String q = l.get(0); // the element at index 0 of the list
            // Request from the monitoring program.
            if (q.equals("monitor")) {
                JSONObject ret = new JSONObject();
                ret.put("total", 1);
                JSONObject item = new JSONObject();
                item.put("answer", "alive");
                JSONArray result = new JSONArray();
                result.add(item);
                ret.put("result", result);
                buf = Unpooled.copiedBuffer(ret.toJSONString().getBytes("UTF-8"));
                res.setContent(buf);
                res.headers().set("Content-Type", "text/html; charset=UTF-8");
                return;
            }
            log.info("question=" + q);
            List<String> clientIps = mapParameters.get("clientIp");
            String clientIp = "";
            if (null != clientIps && clientIps.size() == 1) {
                clientIp = clientIps.get(0);
                log.info("clientIp=" + clientIp);
            }
            Query query = null;
            PriorityQueue<ScoreDoc> pq = new PriorityQueue<ScoreDoc>(MAX_RESULT) {
                @Override
                protected boolean lessThan(ScoreDoc a, ScoreDoc b) {
                    return a.score < b.score;
                }
            };
            MyCollector collector = new MyCollector(pq);
            // JSONObject comes from a package that can convert to and from formats such as XML;
            // the JSON object created here is what our page receives.
            JSONObject ret = new JSONObject();
            /* TopDocs holds the total hit count plus document IDs and scores.
             * No search has gone through the collector yet, so it starts out empty. */
            TopDocs topDocs = collector.topDocs();
            Analyzer analyzer = new IKAnalyzer(true);
            /* QueryParser is the query analyzer; it can produce many kinds of Query objects.
             * "question" is the default field for query terms. */
            // It turns the user's input string into an internal Query or group of Query objects.
            QueryParser qp = new QueryParser(Version.LUCENE_4_9, "question", analyzer);
            if (topDocs.totalHits == 0) { // number of hits so far
                // First attempt: every term must match.
                qp.setDefaultOperator(Operator.AND);
                query = qp.parse(q);
                log.info("lucene query=" + query.toString());
                topDocs = indexSearcher.search(query, 20);
                log.info("elapse " + collector.getElapse() + " " + collector.getElapse2());
            }
            if (topDocs.totalHits == 0) {
                // Fallback: any term may match.
                qp.setDefaultOperator(Operator.OR);
                query = qp.parse(q);
                log.info("lucene query=" + query.toString());
                topDocs = indexSearcher.search(query, 20);
                log.info("elapse " + collector.getElapse() + " " + collector.getElapse2());
            }
            ret.put("total", topDocs.totalHits); // total number of hits
            ret.put("q", q); // echo back the question that was searched
            JSONArray result = new JSONArray();
            String firstAnswer = "";
            for (ScoreDoc d : topDocs.scoreDocs) {
                Document doc = indexSearcher.doc(d.doc);
                String question = doc.get("question");
                String answer = doc.get("answer");
                JSONObject item = new JSONObject();
                item.put("question", question);
                if (firstAnswer.equals("")) {
                    firstAnswer = answer;
                }
                item.put("answer", answer);
                /* Lucene scores each document with its scoring mechanism and returns
                 * the results sorted by score in descending order. */
                item.put("score", d.score);
                item.put("doc", d.doc);
                result.add(item);
            }
            ret.put("result", result);
            log.info("response=" + ret);
            logChat.info(clientIp + " [" + q + "] [" + firstAnswer + "]");
            // copiedBuffer returns a new ByteBuf holding a copy of the given bytes.
            buf = Unpooled.copiedBuffer(ret.toJSONString().getBytes("UTF-8"));
            // Print the JSON text itself; printing the byte[] would only show an object reference.
            System.out.println(ret.toJSONString());
        } else {
            buf = Unpooled.copiedBuffer("error".getBytes("UTF-8"));
        }
        res.setContent(buf);
        res.headers().set("Content-Type", "text/html; charset=UTF-8");
    }

    public static class MyCollector extends TopDocsCollector<ScoreDoc> {
        protected Scorer scorer;
        protected AtomicReader reader;
        protected int baseDoc;
        protected HashSet<Integer> set = new HashSet<Integer>(); // hash codes of answers already collected
        protected long elapse = 0;
        protected long elapse2 = 0;

        public long getElapse2() {
            return elapse2;
        }

        public void setElapse2(long elapse2) {
            this.elapse2 = elapse2;
        }

        public long getElapse() {
            return elapse;
        }

        public void setElapse(long elapse) {
            this.elapse = elapse;
        }

        protected MyCollector(PriorityQueue<ScoreDoc> pq) {
            super(pq);
        }

        @Override
        public void setScorer(Scorer scorer) throws IOException {
            this.scorer = scorer;
        }

        @Override
        public void collect(int doc) throws IOException {
            long t1 = System.currentTimeMillis();
            if (this.totalHits > MAX_TOTAL_HITS) {
                return;
            }
            // doc is relative to the current segment; add baseDoc to get the index-wide ID.
            int globalDoc = this.baseDoc + doc;
            String answer = this.getAnswer(globalDoc);
            long t3 = System.currentTimeMillis();
            this.elapse2 += t3 - t1;
            if (set.contains(answer.hashCode())) {
                return; // an answer with the same hash code was already collected
            } else {
                set.add(answer.hashCode());
                ScoreDoc sd = new ScoreDoc(globalDoc, this.scorer.score());
                // Keep at most MAX_RESULT entries; on overflow the lowest-scoring one is dropped.
                this.pq.insertWithOverflow(sd);
                this.totalHits++;
            }
            long t2 = System.currentTimeMillis();
            this.elapse += t2 - t1;
        }

        @Override
        public void setNextReader(AtomicReaderContext context) throws IOException {
            this.reader = context.reader();
            this.baseDoc = context.docBase;
        }

        @Override
        public boolean acceptsDocsOutOfOrder() {
            return false;
        }

        private String getAnswer(int doc) throws IOException {
            Document d = indexSearcher.doc(doc);
            return d.get("answer");
        }
    }
}
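The idea behind MyCollector (skip answers that were already seen, keep only the best-scoring few) can be tried outside Lucene. The sketch below uses `java.util.PriorityQueue` rather than Lucene's own `PriorityQueue`, and the `TopAnswers` and `Hit` names are assumptions made for this example. It shows the technique, not the tutorial's implementation.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Set;

// Sketch of "deduplicate by answer, keep the N best scores" in plain Java.
class TopAnswers {
    static final class Hit {
        final String answer;
        final float score;

        Hit(String answer, float score) {
            this.answer = answer;
            this.score = score;
        }
    }

    // Returns at most n hits with distinct answers, best score first.
    static List<Hit> topN(List<Hit> hits, int n) {
        Set<String> seen = new HashSet<String>();
        // Min-heap: the lowest-scoring kept hit sits at the head.
        PriorityQueue<Hit> heap = new PriorityQueue<Hit>(
                Math.max(1, n), (a, b) -> Float.compare(a.score, b.score));
        for (Hit h : hits) {
            if (!seen.add(h.answer)) {
                continue; // the same answer was already collected
            }
            heap.add(h);
            if (heap.size() > n) {
                heap.poll(); // drop the current lowest score
            }
        }
        List<Hit> result = new ArrayList<Hit>(heap);
        result.sort((a, b) -> Float.compare(b.score, a.score));
        return result;
    }

    public static void main(String[] args) {
        List<Hit> hits = Arrays.asList(
                new Hit("a", 1f), new Hit("b", 3f), new Hit("a", 5f), new Hit("c", 2f));
        for (Hit h : topN(hits, 2)) {
            System.out.println(h.answer + " " + h.score);
        }
    }
}
```

As in the collector, the first occurrence of an answer wins: the second "a" is skipped even though its score is higher.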
NettyHttpServletResponse.java
package com.shareditor.chatbotv1;

import io.netty.buffer.ByteBuf;
import io.netty.handler.codec.http.DefaultHttpResponse;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;

public class NettyHttpServletResponse extends DefaultHttpResponse implements FullHttpResponse {
    private ByteBuf content;

    public NettyHttpServletResponse(HttpVersion version, HttpResponseStatus status) {
        super(version, status, true);
    }

    public HttpHeaders trailingHeaders() {
        // TODO Auto-generated method stub
        return null;
    }

    public void setContent(ByteBuf buf) {
        this.content = buf;
    }

    public ByteBuf content() {
        return content;
    }

    public int refCnt() {
        // TODO Auto-generated method stub
        return 0;
    }

    public boolean release() {
        // TODO Auto-generated method stub
        return false;
    }

    public boolean release(int decrement) {
        // TODO Auto-generated method stub
        return false;
    }

    public FullHttpResponse copy(ByteBuf newContent) {
        // TODO Auto-generated method stub
        return null;
    }

    public FullHttpResponse copy() {
        // TODO Auto-generated method stub
        return null;
    }

    public FullHttpResponse retain(int increment) {
        // TODO Auto-generated method stub
        return null;
    }

    public FullHttpResponse retain() {
        // TODO Auto-generated method stub
        return null;
    }

    public FullHttpResponse touch() {
        // TODO Auto-generated method stub
        return null;
    }

    public FullHttpResponse touch(Object hint) {
        // TODO Auto-generated method stub
        return null;
    }

    public FullHttpResponse duplicate() {
        // TODO Auto-generated method stub
        return null;
    }

    public FullHttpResponse setProtocolVersion(HttpVersion version) {
        // TODO Auto-generated method stub
        return null;
    }

    public FullHttpResponse setStatus(HttpResponseStatus status) {
        // TODO Auto-generated method stub
        return null;
    }
}
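The `refCnt()`, `retain()` and `release()` methods above are left as auto-generated stubs. For readers wondering what they are meant to do, here is a minimal model of the usual reference-counting contract: a new object starts with one reference, `retain()` adds one, and `release()` removes one and reports whether the object was freed. The `RefCountedBuffer` class is hypothetical and deliberately ignores thread safety; it is not Netty code.

```java
// Minimal, single-threaded model of a reference-counting contract.
// Illustrative only: the class name and error handling are assumptions.
class RefCountedBuffer {
    private int refCnt = 1; // a new object starts with one reference
    private boolean deallocated = false;

    int refCnt() {
        return refCnt;
    }

    // Adds one reference and returns this object so calls can be chained.
    RefCountedBuffer retain() {
        if (refCnt <= 0) {
            throw new IllegalStateException("already released");
        }
        refCnt++;
        return this;
    }

    // Removes one reference; returns true only when the count reaches zero.
    boolean release() {
        if (refCnt <= 0) {
            throw new IllegalStateException("already released");
        }
        refCnt--;
        if (refCnt == 0) {
            deallocated = true; // a real buffer would free its memory here
            return true;
        }
        return false;
    }

    boolean isDeallocated() {
        return deallocated;
    }

    public static void main(String[] args) {
        RefCountedBuffer buf = new RefCountedBuffer();
        buf.retain();
        System.out.println(buf.release()); // one reference is still held
        System.out.println(buf.release()); // last reference released
    }
}
```

A response class that returns fixed values from these methods does not follow this contract, which is worth keeping in mind if the stubs are ever filled in.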