前面我们已经做好了模型的准备,下面就可以进行Spring Boot工程的开发了。这里我使用的是IntelliJ IDEA,借助Spring initializer创建一个Spring Boot Maven工程,添加Spring Boot Web starter和Lombok,确保IDEA已经安装了Lombok插件,这个是使用IDEA创建Spring Boot工程的基本操作,就不多说了。然后正式开始撸代码。
这里主要添加两个依赖项,分别是TensorFlow Java版(这里我用的是1.12.0)和commons-io,分别用于模型推理和模型载入:
<dependency>
<groupId>org.tensorflowgroupId>
<artifactId>tensorflowartifactId>
<version>1.12.0version>
dependency>
<dependency>
<groupId>commons-iogroupId>
<artifactId>commons-ioartifactId>
<version>2.6version>
dependency>
为了后期能直接把模型文件打到jar包之中而无需额外指定,这里将模型文件和词汇表引入资源文件。这么做还比较合适的重要原因也是我们这里用到的ALBERT本身的模型文件比较小,只有不到16M。如果是原生的BERT,模型文件动辄几百MB甚至更大,那直接放入资源文件可能就不是很合适了。这里单独创建一个文件夹albert-model存放:
这里借助Sprint Boot的_@_PostConstruct 注解,使Spring Boot工程构建完成后就自动把模型和词汇表加载到内存中,然后对应的类依赖注入,就不用每次推理时重新加载模型和创建相关对象,直接拿来用就可以,大大提高了效率。另外借助lombok,可自动生成Getter/Setter/Constructor使代码更加精简:
@Getter // lombok生成getter方法,方便取数据
@Component("loadALBERT")
public class LoadALBERT {
private Logger logger = LoggerFactory.getLogger(getClass());
private Session session; // TensorFlow Session 对象,可完成推理
private final int vectorDim = 312; // albert 向量维数
private final int maxSupportLen = 510; // albert 支持的最大长度,原始为512,去掉首尾的[CLS]和[SEP],即为510
private final String modelPath = "albert-model/albert_tiny_zh_google.pb"; // 模型资源文件路径
@PostConstruct
private void init(){
loadGraph(); // 调用加载方法
}
private void loadGraph(){
Graph graph = new Graph();
try{
// 获取资源文件中的模型文件输入流
InputStream inputStream = this.getClass().getClassLoader()
.getResourceAsStream(modelPath);
// 使用commons-io中的IOUtils将模型文件输入流转化为byte数组
byte[] graphPb = IOUtils.toByteArray(inputStream);
//初始化TensorFlow graph
graph.importGraphDef(graphPb);
// 把graph装入一个新的Session,可运行推理
this.session = new Session(graph);
logger.info("ALBERT checkpoint loaded @ {}, vector dimension - {}, maxSupportLen - {}",
modelPath, vectorDim, maxSupportLen);
} catch (Exception e){
logger.error("Failed to load ALBERT checkpoint @ {} ! - {}", modelPath, e.toString());
}
}
}
其余说明参见代码中的注释,另外使用slf4j打印日志。同理,词汇表也需要预先自加载,并使用HashMap存储字符-Token ID映射:
@Getter
@Component("loadVocab")
public class LoadVocab {
private Logger logger = LoggerFactory.getLogger(getClass());
private Map<String, Integer> vocabTable = new HashMap<>();
@PostConstruct
private void init(){
loadVocab();
}
private void loadVocab(){
try{
InputStreamReader inputReader = new InputStreamReader(
this.getClass().getClassLoader().getResourceAsStream("albert-model/vocab.txt"));
BufferedReader bf = new BufferedReader(inputReader);
String str;
int n = 0;
// 按行读,分配id
while ((str = bf.readLine()) != null) {
String lineWord = str.trim();
this.vocabTable.put(lineWord, n);
n++;
}
bf.close();
inputReader.close();
logger.info("ALBERT vocab loaded. total number - {} ", n);
}catch (Exception e){
logger.error("failed to load ALBERT vocab! - {}", e.toString());
}
}
}
这样以上两个类在Spring Boot启动时会自动调用加载方法,完成必要数据的加载,以供其他依赖对象使用。具体包路径是com.aiwiscal.albert.model。
定义了几个输入输出,以及中间结果类,比较零碎,可参照开源代码与注释理解:
类名 | 说明 |
---|---|
com.aiwiscal.albert.param.InputText | 原始输入类 |
com.aiwiscal.albert.param.InputTextValid | 有效输入类,对原始输入去空字符或补充[PAD] |
com.aiwiscal.albert.param.OutputToken | 模型推理输入类,包含tokenId和segmentId |
com.aiwiscal.albert.param.OutputVector | 最终请求返回类 |
这里实现了一个简易分词器,把原始输入文本简单清洗后,切分为单字符查表获得TokenId。这里就需要把之前加载好的词汇表映射注入进来(@Autowired):
@Component("tokenizer")
public class Tokenizer {
private Logger logger = LoggerFactory.getLogger(getClass());
@Autowired
private LoadVocab loadVocab; // 注入词汇表映射
@Autowired
private LoadALBERT loadALBERT; // 注入albert模型数据
// 原始输入清洗,切分字符并查表得到tokenId和segmentId
// 返回OutputToken类对象
public OutputToken tokenize(InputText inputText){
OutputToken outputToken = new OutputToken();
try{
// 简单清洗文本,生成有效输入类InputTextValid对象
InputTextValid inputTextValid = validateText(inputText);
// 获得tokenId
float[] tokenId = getTokenId(inputTextValid);
// 获得segmentId
float[] segmentId = getSegmentId(inputTextValid);
outputToken.setTokenId(tokenId);
outputToken.setSegmentId(segmentId);
outputToken.setSuccess(true);
outputToken.setInputTextValid(inputTextValid);
logger.debug("Text tokenized ...");
}catch (Exception e){
logger.error("Failed to tokenize the text - {}", e.toString());
outputToken.setSuccess(false);
}
return outputToken;
}
// ...... 以下省略
}
这里重点提一下TokenId的获取,通过注入的loadVocab对象获得其中的HashMap映射表,可以方便的取到对应字符的Token Id:
// 查表获得tokenId
private float[] getTokenId(InputTextValid inputTextValid){
String[] textTokenList = inputTextValid.getTextTokenList();
Map<String, Integer> vocabTable = loadVocab.getVocabTable(); // 获得注入loadVacab对象中的字符映射表
float[] tokenId = new float[textTokenList.length + 2];
tokenId[0] = 101; // 头部添加[CLS]标记,token id为101
tokenId[tokenId.length - 1] = 102; // 尾部添加[SEP]标记,token id为102
for (int i = 0; i < textTokenList.length; i++) {
String currentCharStr = textTokenList[i];
if(!vocabTable.containsKey(currentCharStr)){
currentCharStr = "[UNK]"; // 不在词汇表中的,设定为[UNK]
}
tokenId[i+1] = vocabTable.get(currentCharStr); // 查表
}
return tokenId;
}
而对于segment id,这里只是输入单文本,所以直接生成全0数组表示:
private float[] getSegmentId(InputTextValid inputTextValid){
return new float[inputTextValid.getTextTokenList().length + 2];
}
这是向量生成的核心,通过运行模型来获取结果,作为一个Service。需要注入模型数据,调用TensorFlow组件完成推理,参照以下代码中的注释:
@Service
public class InferALBERT {
private Logger logger = LoggerFactory.getLogger(getClass());
@Autowired
private LoadALBERT loadALBERT; // 注入模型数据
@Autowired
private Tokenizer tokenizer; // 注入分词器
private float[] inferArr(float[] inputToken, float[] inputSegment){
// 将1维数组扩展为2维以满足输入需要
float[][] inputToken2D = new float[1][inputToken.length];
float[][] inputSegment2D = new float[1][inputSegment.length];
System.arraycopy(inputToken, 0, inputToken2D[0], 0, inputToken.length);
System.arraycopy(inputSegment, 0, inputSegment2D[0], 0, inputSegment.length);
// 调用TensorFlow会话(Session)中的runner,实现模型推理
// 注入数据使用feed,取结果使用fetch,根据输入输出tensor的名称操作
Tensor result = loadALBERT.getSession().runner()
.feed("Input-Token", Tensor.create(inputToken2D))
.feed("Input-Segment", Tensor.create(inputSegment2D))
.fetch("output_1")
.run().get(0);
float[] ret = new float[loadALBERT.getVectorDim()];
// 将结果的Tensor对象内部数据拷贝至原生数组
result.copyTo(ret);
return ret;
}
public OutputVector infer(InputText inputText){
// 调用了上述inferArr
// 总流程,null检查,原始输入处理,分词,模型推理并打印相关日志,最终返回
// 具体代码省略 ......
// ......
}
}
这里用上了前面在Python环境中生成pb文件时得到的输入输出Tensor名称,以实现正确的数据注入,推理和结果获取。最终再把结果拷贝到Java原生数组方便后续的处理。当然在外层还封装了一个总流程方法,完成一系列操作并且返回最终的OutputVector对象。
基于@RestController注解完成对外部请求的处理,调用InferALBERT(Service)推理后返回给请求端:
@RestController
public class AlbertVecController {
private Logger logger = LoggerFactory.getLogger(getClass());
@Autowired
private InferALBERT inferALBERT; // 注入核心推理类Service
@Value("${server.port}")
private int port;
@RequestMapping("/")
public String runStatus(){
return String.format("======== ALBERT Vector Service is running @ port %d =======", this.port);
}
@PostMapping(path="/vector") //处理post向量生成请求
public OutputVector getVector(@RequestBody InputText inputText){
return inferALBERT.infer(inputText);
}
}
在配置文件(application.properties)里设定运行端口为7777(server.port=7777)。另外实现Spring Boot CommandLineRunner接口,使Spring Boot启动时对向量生成进行简单自检,位于com.aiwiscal.albert.starter.ServicePass类,若成功启动则输出日志**“ALBERT Vector Service is ready to listen …”。**
运行Spring Boot Application主应用,这里是com.aiwiscal.albert.AlbertVecApplication,查看日志输入:
. ____ _ __ _ _
/\\ / ___'_ __ _ _(_)_ __ __ _ \ \ \ \
( ( )\___ | '_ | '_| | '_ \/ _` | \ \ \ \
\\/ ___)| |_)| | | | | || (_| | ) ) ) )
' |____| .__|_| |_|_| |_\__, | / / / /
=========|_|==============|___/=/_/_/_/
:: Spring Boot :: (v2.2.5.RELEASE)
2020-03-28 10:17:18.167 INFO 10964 --- [ main] c.aiwiscal.albert.AlbertVecApplication : Starting AlbertVecApplication on LAPTOP-MVOM84AD with PID 10964 (E:\IdeaProjects\albert-vec\target\classes started by Wenhan in E:\IdeaProjects\albert-vec)
2020-03-28 10:17:18.171 INFO 10964 --- [ main] c.aiwiscal.albert.AlbertVecApplication : No active profile set, falling back to default profiles: default
2020-03-28 10:17:18.939 INFO 10964 --- [ main] o.s.b.w.embedded.tomcat.TomcatWebServer : Tomcat initialized with port(s): 7777 (http)
2020-03-28 10:17:18.946 INFO 10964 --- [ main] o.apache.catalina.core.StandardService : Starting service [Tomcat]
2020-03-28 10:17:18.946 INFO 10964 --- [ main] org.apache.catalina.core.StandardEngine : Starting Servlet engine: [Apache Tomcat/9.0.31]
2020-03-28 10:17:19.008 INFO 10964 --- [ main] o.a.c.c.C.[Tomcat].[localhost].[/] : Initializing Spring embedded WebApplicationContext
2020-03-28 10:17:19.009 INFO 10964 --- [ main] o.s.web.context.ContextLoader : Root WebApplicationContext: initialization completed in 782 ms
2020-03-28 10:17:20.094609: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2
2020-03-28 10:17:20.103 INFO 10964 --- [ main] com.aiwiscal.albert.model.LoadALBERT : ALBERT checkpoint loaded @ albert-model/albert_tiny_zh_google.pb, vector dimension - 312, maxSupportLen - 510
2020-03-28 10:17:20.112 INFO 10964 --- [ main] com.aiwiscal.albert.model.LoadVocab : ALBERT vocab loaded. total number - 21128
2020-03-28 10:17:20.207 INFO 10964 --- [ main] o.s.s.concurrent.ThreadPoolTaskExecutor : Initializing ExecutorService 'applicationTaskExecutor'
2020-03-28 10:17:20.333 INFO 10964 --- [ main] o.s.b.w.embedded.tomcat.TomcatWebServer : Tomcat started on port(s): 7777 (http) with context path ''
2020-03-28 10:17:20.336 INFO 10964 --- [ main] c.aiwiscal.albert.AlbertVecApplication : Started AlbertVecApplication in 2.562 seconds (JVM running for 3.697)
2020-03-28 10:17:20.339 INFO 10964 --- [ main] com.aiwiscal.albert.service.InferALBERT : Raw Input: Text - "你好 世 界, 世界你好!", ValidLength - 5
2020-03-28 10:17:20.339 INFO 10964 --- [ main] com.aiwiscal.albert.service.InferALBERT : Validated Input: Text - "你好世界,", ValidLength - 5
2020-03-28 10:17:21.756 INFO 10964 --- [ main] com.aiwiscal.albert.service.InferALBERT : ALBERT vector generation finished - time cost: 1417 ms.
2020-03-28 10:17:21.757 INFO 10964 --- [ main] com.aiwiscal.albert.starter.ServicePass : ALBERT Vector Service is ready to listen ...
看到最后的"2020-03-28 10:17:21.757 INFO 10964 — [ main] com.aiwiscal.albert.starter.ServicePass : ALBERT Vector Service is ready to listen …",说明已经成功启动了,在浏览器里输入127.0.0.1:7777回车:
以上大致说明了Spring Boot工程的开发思路和流程,总体还是比较简单的,下一步会在Python中对我们启动的上述向量服务进行请求,进行应用示例。
Python支持工程开源代码:https://github.com/Aiwiscal/albert-vec-support
Java主工程开源代码:https://github.com/Aiwiscal/albert-vec
喜欢请给star哦~