目录
0. 前言
1. 关于CWRU数据集
2. 数据读取
3.数据预处理
4. 训练
最近开始搞智能故障诊断方面的工作,一上来面对的就是要各种炼丹。虽然众所周知在炼丹方面是python比较擅长,但由于本人已经写了不少年java形成了路径依赖,电脑上早就装好了dl4j的环境,本着能凑合用就绝不换引擎的原则,决定拿这玩意继续对付到不能用为止。
dl4j的生态其实并不是很好,文档都不是很全;在国内生态就更差,几乎还没见到有人用过。这次写这篇文章呢,也没寻思给谁看或者能帮到谁,就当是给自己做个备忘好了。安装和基础教程方面这里推荐一下b站寒沧的教学,当年第一次安的时候也是帮了我大忙。地址:
的个人空间_哔哩哔哩_Bilibili
这是凯斯西储大学提供的一个数据集,在故障诊断领域属于入门级的数据集,大概相当于MNIST的地位。特征非常明显,分类也非常简单。下载地址:
Download a Data File | Case School of Engineering | Case Western Reserve University
因为数据太琐碎了也并不是全都用了,这里放上我自己用的,忘了从哪下的其中一部分数据集:
链接: https://pan.baidu.com/s/1fw3bCLV7qu1ZRQxVvoxJIw 提取码: 4v24
文件说明:(别处抄的)
文件为Matlab格式
每个文件包含风扇和驱动端振动数据,以及电机转速,文件中文件变量命名如下:
DE - drive end accelerometer data 驱动端振动数据
FE - fan end accelerometer data 风扇端振动数据
BA - base accelerometer data 基座振动数据
time - time series data 时间序列数据
RPM- rpm during testing 单位转每分钟 除以60则为转频
数据采集频率分别为:
数据集A:在12Khz采样频率下的驱动端轴承故障数据
数据集B:在48Khz采样频率下的驱动端轴承故障数据
数据集C:在12Khz采样频率下的风扇端轴承故障数据
数据集D:以及正常的轴承数据(采样频率应该是48k的)
数据集B解读:在48Khz采样频率下的驱动端轴承故障直径又分为0.007英寸、0.014英寸、0.028英寸三种类别,每种故障下负载又分为0马力、1马力、2马力、3马力。在每种故障的每种马力下有轴承内圈故障、轴承滚动体故障、轴承外环故障(由于轴承外环位置一般比较固定,因此外环故障又分为3点钟、6点钟和12点钟三种类别)。
因为上面下载的源文件是matlab的格式,打开看了一下发现是二进制的,不像csv或者json那样可以很方便地自己写parser去读,所以要用别的库,这里选择JMatio。
由于DL4J必须使用Java10还是11以上的版本(反正我用的Java17),而Jar版本的JMatio因为太古老了而对高版本Java不兼容,如果强行运行的话会报错:
java.lang.reflect.InaccessibleObjectException: Unable to make public jdk.internal.ref.Cleaner java.nio.DirectByteBuffer.cleaner() accessible。
看网上其他项目的解决办法是在JVM里加启动参数:
--add-opens=java.base/java.nio=ALL-UNNAMED
但对这个库仍不好使,会继续报错:
Exception in thread "main" java.lang.NoClassDefFoundError: sun/misc/Cleaner
搞得我很是头大,差点就因为这点小事弃坑了。鼓捣了半天,最后偶然间在翻这个项目的github的时候,发现两年前的一个commit修复了对高版本Java的支持。最后的解决办法是直接把这个版本的源码扔进项目里。当然也可以自己把代码打个包,以及修复后的版本应该在Maven也是有的,我这里就懒得弄了,毕竟只是个学习项目,能凑合使就得了。地址:
GitHub - gradusnikov/jmatio: JMatIO - Matlab's MAT-file I/O in JAVA
读数据的方式也很简单暴力,直接遍历文件夹下的所有文件,然后按文件名填信息就好了。
public class CWRUDataParser {
public static void parse() throws Exception {
var path = "你的path";
for (String fname : new File(path).list()) {
//System.out.println(fname);
CWRUData d = new CWRUData();
d.name = fname;
CWRUDataManager.dataList.add(d);
MatFileReader reader = new MatFileReader(path + "\\" + fname);
var content = reader.getContent();
if (fname.contains("_B")) {
d.err_type = "Ball";
} else if (fname.contains("_IR")) {
d.err_type = "IR";
} else if (fname.contains("_OR")) {
d.err_type = "OR";
} else {
d.err_type = "Normal";
}
if (fname.contains("028")) {
d.depth = 28;
} else if (fname.contains("021")) {
d.depth = 21;
} else if (fname.contains("014")) {
d.depth = 14;
} else if (fname.contains("007")) {
d.depth = 7;
}
if (fname.contains("_0_")) {
d.load = 0;
} else if (fname.contains("_1_")) {
d.load = 1;
} else if (fname.contains("_2_")) {
d.load = 2;
} else if (fname.contains("_3_")) {
d.load = 3;
}
if (fname.contains("@3")) {
d.pos = 3;
} else if (fname.contains("@6")) {
d.pos = 6;
} else if (fname.contains("@12")) {
d.pos = 12;
}
for (String key : content.keySet()) {
var value = content.get(key);
if (key.contains("DE")) {
d.DE = d.err_type.equals("Normal")?toDoubleArray4x(value): toDoubleArray(value);
//System.out.println(d.DE.length);
} else if (key.contains("FE")) {
d.FE = d.err_type.equals("Normal")?toDoubleArray4x(value): toDoubleArray(value);
//System.out.println(d.FE.length);
} else if (key.contains("BA")) {
d.BA = d.err_type.equals("Normal")?toDoubleArray4x(value): toDoubleArray(value);
//System.out.println(d.BA.length);
} else if (key.contains("RPM")){
d.rpm = Double.valueOf(value.contentToString().split("=")[1]);
//System.out.println(d.rpm);
}
}
}
}
public static double[] toDoubleArray(MLArray ma) {
MLDouble md = (MLDouble) ma;
int m = md.getM();
double[] data = new double[m];
for (int i = 0; i < m; i++) {
data[i] = md.get(i, 0);
}
return data;
}
public static double[] toDoubleArray4x(MLArray ma) {
MLDouble md = (MLDouble) ma;
int m = md.getM();
double[] data = new double[m/4];
for (int i = 0; i < m/4; i++) {
data[i] = (md.get(4*i, 0)+md.get(4*i+1, 0)+md.get(4*i+2, 0)+md.get(4*i+3, 0))/4;
}
return data;
}
}
其中CWRUData是我自己封装的一个简单结构:
public class CWRUData {
public String name;
public String err_type;
public int depth;
public int pos = 0;
public int load;
public double[] DE;
public double[] FE;
public double[] BA;
public double rpm;
public List blocks = new ArrayList();
public void cut(int size, int num) {
for (int current = 0; current < DE.length-size ;current += DE.length/num) {
double[] blockDE = new double[size];
for (int i = 0; i < size; i++) {
blockDE[i] = DE[i + current];
}
CWRUBlock block = new CWRUBlock(size, this);
block.DE = blockDE;
blocks.add(block);
}
}
public int type() {
if (err_type.equals("Ball")) {
return depth / 7;
}
if (err_type.equals("IR")) {
return 3 + depth / 7;
}
if (err_type.equals("OR")) {
return 6 + depth / 7;
}
return 0;
}
public void print() {
System.out.println("length:" + DE.length);
System.out.println("block num:" + blocks.size());
}
}
补充一下上面代码里没提到的东西。一是关于toDoubleArray4x():由于故障数据的采样率是12k,而正常数据是48k,为了保证二者频率一样,因此在存入正常数据的时候使用的是四合一平均池化的toDoubleArray4x()。二是type()的作用:是把数据集根据故障类型分为10类,正常的一类,故障的三类根据depth为7/14/21每个又分为三类,1+3*3=10。由于depth为28的数据不全,舍弃。OR错误类型的数据只使用位置为6的,其他位置的数据舍弃。
CWRU数据集提供的数据是一段很长很长的离散采样序列,需要用滑窗切成一段一段的才能处理。然后分析的时候一般只用DE的数据,因为其他的好像不全。切的方法上面已经给出了,用的时候只需要:
CWRUDataManager.dataList.forEach(d->d.cut(512, 400));
其中CWRUBlock也是自己封装的一个数据类型,代码:
public class CWRUBlock {
public final int size;
public final CWRUData source;
public double[] DE;
public double[] FE;
public double[] BA;
public CWRUBlock(int size, CWRUData source) {
this.size = size;
this.source = source;
}
}
全都处理完之后就可以做dl4j的DataSet了。代码:
public static DataSet genGenericDataSet1() {
List blocks = new ArrayList();
for (var data: dataList) {
if (data.pos != 0 && data.pos != 6) continue;
if (data.depth == 28) continue;
data.blocks.forEach(blocks::add);
}
Collections.shuffle(blocks);
INDArray[] input = blocks.stream().map(b->Nd4j.create(b.DE, 1, b.DE.length)).toArray(INDArray[]::new);
INDArray inputs = Nd4j.vstack(input);
INDArray[] output = blocks.stream().map(b->genOutputFromType(b.source.type())).toArray(INDArray[]::new);
INDArray outputs = Nd4j.vstack(output);
DataSet dataSet = new DataSet(inputs, outputs);
return dataSet;
}
这里为什么不用dl4j自带的shuffle,而要使用Collections.shuffle呢?这是因为我也不知道为啥dl4j自带的shuffle会直接崩掉jvm。。真是神奇的框架捏。
网络结构(一个非常简单的多层感知机):
public static MultiLayerConfiguration CWRUANN() {
MultiLayerConfiguration builder = new NeuralNetConfiguration.Builder()
.seed(19260817L)
.updater(new Sgd(0.01))
.weightInit(WeightInit.XAVIER)
.list()
.layer(new DenseLayer.Builder().nIn(512).nOut(128)
.activation(Activation.RELU)
.build())
.layer(new DenseLayer.Builder().nIn(128).nOut(32)
.activation(Activation.RELU)
.build())
.layer(new DenseLayer.Builder().nIn(32).nOut(10)
.activation(Activation.RELU)
.build())
.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX)
.nIn(10).nOut(10).build())
.build();
return builder;
}
开启监视器,以及训练:
public static void ANN(DataSet train, DataSet test) throws Exception {
MultiLayerNetwork model = new MultiLayerNetwork(NetFactory.CWRUANN());
model.init();
UIServer server = UIServer.getInstance();
server.enableRemoteListener();
StatsStorageRouter remoteUIRouter = new RemoteUIStatsStorageRouter("http://localhost:9000");
model.setListeners(new StatsListener(remoteUIRouter));
DataSetIterator iterator = getIter(train, 20);
for (int x = 0; x< 10000; x++) {
if (!iterator.hasNext()) {
iterator = getIter(train, 20);
}
model.fit(iterator);
if (x % 10 == 0) {
model.save(new File("模型保存路径" + x + ".zip"), true);
Evaluation eval = new Evaluation(10);
INDArray output = model.output(test.getFeatures());
eval.eval(test.getLabels(), output);
log.info(eval.stats());
}
}
}
private static DataSetIterator getIter(final DataSet set, final int batchSize) {
final List list = set.asList();
Collections.shuffle(list, new Random());
return new ListDataSetIterator(list,batchSize);
}
训练结果:
========================Evaluation Metrics========================
# of classes: 10
Accuracy: 0.9831
Precision: 0.9833
Recall: 0.9834
F1 Score: 0.9833
Precision, recall & F1: macro-averaged (equally weighted avg. of 10 classes)
=========================Confusion Matrix=========================
0 1 2 3 4 5 6 7 8 9
-----------------------------------------
157 0 0 0 0 0 0 0 0 0 | 0 = 0
0 141 0 4 0 0 0 0 0 0 | 1 = 1
1 0 177 0 1 2 0 0 0 0 | 2 = 2
0 1 1 164 0 2 0 0 1 0 | 3 = 3
0 0 1 0 163 1 1 0 0 0 | 4 = 4
0 4 0 1 1 172 0 0 0 0 | 5 = 5
0 0 0 0 0 0 145 0 0 0 | 6 = 6
0 0 0 0 0 0 0 159 0 1 | 7 = 7
0 0 0 0 0 0 0 0 161 0 | 8 = 8
0 0 0 0 0 0 0 4 0 133 | 9 = 9
Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
==================================================================
可以看出效果是非常不错的,真是简单的数据集捏。