使用DL4J对CWRU数据集进行简单分类

目录

0. 前言

1. 关于CWRU数据集

2. 数据读取

3.数据预处理

4. 训练


0. 前言

最近开始搞智能故障诊断方面的工作,一上来面对的就是要各种炼丹。虽然众所周知在炼丹方面是python比较擅长,但由于本人已经写了不少年java形成了路径依赖,电脑上早就装好了dl4j的环境,本着能凑合用就绝不换引擎的原则,决定拿这玩意继续对付到不能用为止。

dl4j的生态其实并不是很好,文档都不是很全;在国内生态就更差,几乎还没见到有人用过。这次写这篇文章呢,也没寻思给谁看或者能帮到谁,就当是给自己做个备忘好了。安装和基础教程方面这里推荐一下b站寒沧的教学,当年第一次安的时候也是帮了我大忙。地址:

的个人空间_哔哩哔哩_Bilibili

1. 关于CWRU数据集

这是凯斯西储大学提供的一个数据集,在故障诊断领域属于入门级的数据集,大概相当于MNIST的地位。特征非常明显,分类也非常简单。下载地址:

Download a Data File | Case School of Engineering | Case Western Reserve University

因为数据太琐碎了也并不是全都用了,这里放上我自己用的,忘了从哪下的其中一部分数据集:

链接: https://pan.baidu.com/s/1fw3bCLV7qu1ZRQxVvoxJIw 提取码: 4v24 

文件说明:(别处抄的)

文件为Matlab格式

每个文件包含风扇和驱动端振动数据,以及电机转速,文件中文件变量命名如下:

DE - drive end accelerometer data 驱动端振动数据

FE - fan end accelerometer data 风扇端振动数据

BA - base accelerometer data 基座振动数据

time - time series data 时间序列数据

RPM- rpm during testing 单位转每分钟 除以60则为转频

数据采集频率分别为

数据集A:在12Khz采样频率下的驱动端轴承故障数据

数据集B:在48Khz采样频率下的驱动端轴承故障数据

数据集C:在12Khz采样频率下的风扇端轴承故障数据

数据集D:以及正常的轴承数据(采样频率应该是48k的)

数据集B解读:在48Khz采样频率下的驱动端轴承故障直径又分为0.007英寸、0.014英寸、0.028英寸三种类别,每种故障下负载又分为0马力、1马力、2马力、3马力。在每种故障的每种马力下有轴承内圈故障、轴承滚动体故障、轴承外环故障(由于轴承外环位置一般比较固定,因此外环故障又分为3点钟、6点钟和12点钟三种类别)。

2. 数据读取

因为上面下载的源文件是matlab的格式,打开看了一下发现是二进制的,不像csv或者json那样可以很方便地自己写parser去读,所以要用别的库,这里选择JMatio。

由于DL4J必须使用Java10还是11以上的版本(反正我用的Java17),而Jar版本的JMatio因为太古老了而对高版本Java不兼容,如果强行运行的话会报错:

java.lang.reflect.InaccessibleObjectException: Unable to make public jdk.internal.ref.Cleaner java.nio.DirectByteBuffer.cleaner() accessible。

看网上其他项目的解决办法是在JVM里加启动参数:

--add-opens=java.base/java.nio=ALL-UNNAMED

但对这个库仍不好使,会继续报错:

Exception in thread "main" java.lang.NoClassDefFoundError: sun/misc/Cleaner

搞得我很是头大,差点就因为这点小事弃坑了。鼓捣了半天,最后偶然间在翻这个项目的github的时候,发现两年前的一个commit修复了对高版本Java的支持。最后的解决办法是直接把这个版本的源码扔进项目里。当然也可以自己把代码打个包,以及修复后的版本应该在Maven也是有的,我这里就懒得弄了,毕竟只是个学习项目,能凑合使就得了。地址:

GitHub - gradusnikov/jmatio: JMatIO - Matlab's MAT-file I/O in JAVA

读数据的方式也很简单暴力,直接遍历文件夹下的所有文件,然后按文件名填信息就好了。

public class CWRUDataParser {

    public static void parse() throws Exception {
        var path = "你的path";
        for (String fname : new File(path).list()) {
            //System.out.println(fname);
            CWRUData d = new CWRUData();
            d.name = fname;
            CWRUDataManager.dataList.add(d);

            MatFileReader reader = new MatFileReader(path + "\\" + fname);
            var content = reader.getContent();

            if (fname.contains("_B")) {
                d.err_type = "Ball";
            } else if (fname.contains("_IR")) {
                d.err_type = "IR";
            } else if (fname.contains("_OR")) {
                d.err_type = "OR";
            } else {
                d.err_type = "Normal";
            }

            if (fname.contains("028")) {
                d.depth = 28;
            } else if (fname.contains("021")) {
                d.depth = 21;
            } else if (fname.contains("014")) {
                d.depth = 14;
            } else if (fname.contains("007")) {
                d.depth = 7;
            }

            if (fname.contains("_0_")) {
                d.load = 0;
            } else if (fname.contains("_1_")) {
                d.load = 1;
            } else if (fname.contains("_2_")) {
                d.load = 2;
            } else if (fname.contains("_3_")) {
                d.load = 3;
            }

            if (fname.contains("@3")) {
                d.pos = 3;
            } else if (fname.contains("@6")) {
                d.pos = 6;
            } else if (fname.contains("@12")) {
                d.pos = 12;
            }

            for (String key : content.keySet()) {
                var value = content.get(key);
                if (key.contains("DE")) {
                    d.DE = d.err_type.equals("Normal")?toDoubleArray4x(value): toDoubleArray(value);
                    //System.out.println(d.DE.length);
                } else if (key.contains("FE")) {
                    d.FE = d.err_type.equals("Normal")?toDoubleArray4x(value): toDoubleArray(value);
                    //System.out.println(d.FE.length);
                } else if (key.contains("BA")) {
                    d.BA = d.err_type.equals("Normal")?toDoubleArray4x(value): toDoubleArray(value);
                    //System.out.println(d.BA.length);
                } else if (key.contains("RPM")){
                    d.rpm = Double.valueOf(value.contentToString().split("=")[1]);
                    //System.out.println(d.rpm);
                }
            }
        }
    }

    public static double[] toDoubleArray(MLArray ma) {
        MLDouble md = (MLDouble) ma;
        int m = md.getM();
        double[] data = new double[m];
        for (int i = 0; i < m; i++) {
            data[i] = md.get(i, 0);
        }
        return data;
    }

    public static double[] toDoubleArray4x(MLArray ma) {
        MLDouble md = (MLDouble) ma;
        int m = md.getM();
        double[] data = new double[m/4];
        for (int i = 0; i < m/4; i++) {
            data[i] = (md.get(4*i, 0)+md.get(4*i+1, 0)+md.get(4*i+2, 0)+md.get(4*i+3, 0))/4;
        }
        return data;
    }
}

其中CWRUData是我自己封装的一个简单结构:

public class CWRUData {
    public String name;
    public String err_type;
    public int depth;
    public int pos = 0;
    public int load;

    public double[] DE;
    public double[] FE;
    public double[] BA;

    public double rpm;

    public List blocks = new ArrayList();

    public void cut(int size, int num) {
        for (int current = 0; current < DE.length-size ;current += DE.length/num) {
            double[] blockDE = new double[size];
            for (int i = 0; i < size; i++) {
                blockDE[i] = DE[i + current];
            }
            CWRUBlock block = new CWRUBlock(size, this);
            block.DE = blockDE;
            blocks.add(block);
        }
    }

    public int type() {
        if (err_type.equals("Ball")) {
            return depth / 7;
        }
        if (err_type.equals("IR")) {
            return 3 + depth / 7;
        }
        if (err_type.equals("OR")) {
            return 6 + depth / 7;
        }
        return 0;
    }

    public void print() {
        System.out.println("length:" + DE.length);
        System.out.println("block num:" + blocks.size());
    }
}

3.数据预处理

补充一下上面代码里没提到的东西。一是关于toDoubleArray4x():由于故障数据的采样率是12k,而正常数据是48k,为了保证二者频率一样,因此在存入正常数据的时候使用的是四合一平均池化的toDoubleArray4x()。二是type()的作用:是把数据集根据故障类型分为10类,正常的一类,故障的三类根据depth为7/14/21每个又分为三类,1+3*3=10。由于depth为28的数据不全,舍弃。OR错误类型的数据只使用位置为6的,其他位置的数据舍弃。

CWRU数据集提供的数据是一段很长很长的离散采样序列,需要用滑窗切成一段一段的才能处理。然后分析的时候一般只用DE的数据,因为其他的好像不全。切的方法上面已经给出了,用的时候只需要:

CWRUDataManager.dataList.forEach(d->d.cut(512, 400));

其中CWRUBlock也是自己封装的一个数据类型,代码:

public class CWRUBlock {

    public final int size;
    public final CWRUData source;
    public double[] DE;
    public double[] FE;
    public double[] BA;

    public CWRUBlock(int size, CWRUData source) {
        this.size = size;
        this.source = source;
    }
}

全都处理完之后就可以做dl4j的DataSet了。代码:

    public static DataSet genGenericDataSet1() {
        List blocks = new ArrayList();
        for (var data: dataList) {
            if (data.pos != 0 && data.pos != 6) continue;
            if (data.depth == 28) continue;
            data.blocks.forEach(blocks::add);
        }
        Collections.shuffle(blocks);
        INDArray[] input = blocks.stream().map(b->Nd4j.create(b.DE, 1, b.DE.length)).toArray(INDArray[]::new);
        INDArray inputs = Nd4j.vstack(input);
        INDArray[] output = blocks.stream().map(b->genOutputFromType(b.source.type())).toArray(INDArray[]::new);
        INDArray outputs = Nd4j.vstack(output);
        DataSet dataSet = new DataSet(inputs, outputs);
        return dataSet;
    }

这里为什么不用dl4j自带的shuffle,而要使用Collections.shuffle呢?这是因为我也不知道为啥dl4j自带的shuffle会直接崩掉jvm。。真是神奇的框架捏。

4. 训练

网络结构(一个非常简单的多层感知机):

    public static MultiLayerConfiguration CWRUANN() {
        MultiLayerConfiguration builder = new NeuralNetConfiguration.Builder()
                .seed(19260817L)
                .updater(new Sgd(0.01))
                .weightInit(WeightInit.XAVIER)
                .list()
                .layer(new DenseLayer.Builder().nIn(512).nOut(128)
                        .activation(Activation.RELU)
                        .build())
                .layer(new DenseLayer.Builder().nIn(128).nOut(32)
                        .activation(Activation.RELU)
                        .build())
                .layer(new DenseLayer.Builder().nIn(32).nOut(10)
                        .activation(Activation.RELU)
                        .build())
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX)
                        .nIn(10).nOut(10).build())
                .build();
        return builder;
    }

开启监视器,以及训练:

    public static void ANN(DataSet train, DataSet test) throws Exception {

        MultiLayerNetwork model = new MultiLayerNetwork(NetFactory.CWRUANN());
        model.init();

        UIServer server = UIServer.getInstance();
        server.enableRemoteListener();
        StatsStorageRouter remoteUIRouter = new RemoteUIStatsStorageRouter("http://localhost:9000");
        model.setListeners(new StatsListener(remoteUIRouter));

        DataSetIterator iterator = getIter(train, 20);

        for (int x = 0; x< 10000; x++) {
            if (!iterator.hasNext()) {
                iterator = getIter(train, 20);
            }
            model.fit(iterator);
            if (x % 10 == 0) {
                model.save(new File("模型保存路径" + x + ".zip"), true);
                Evaluation eval = new Evaluation(10);
                INDArray output = model.output(test.getFeatures());
                eval.eval(test.getLabels(), output);
                log.info(eval.stats());
            }
        }
    }

    private static DataSetIterator getIter(final DataSet set, final int batchSize) {
        final List list = set.asList();
        Collections.shuffle(list, new Random());
        return new ListDataSetIterator(list,batchSize);
    }

训练结果:

========================Evaluation Metrics========================
 # of classes:    10
 Accuracy:        0.9831
 Precision:       0.9833
 Recall:          0.9834
 F1 Score:        0.9833
Precision, recall & F1: macro-averaged (equally weighted avg. of 10 classes)


=========================Confusion Matrix=========================
   0   1   2   3   4   5   6   7   8   9
-----------------------------------------
 157   0   0   0   0   0   0   0   0   0 | 0 = 0
   0 141   0   4   0   0   0   0   0   0 | 1 = 1
   1   0 177   0   1   2   0   0   0   0 | 2 = 2
   0   1   1 164   0   2   0   0   1   0 | 3 = 3
   0   0   1   0 163   1   1   0   0   0 | 4 = 4
   0   4   0   1   1 172   0   0   0   0 | 5 = 5
   0   0   0   0   0   0 145   0   0   0 | 6 = 6
   0   0   0   0   0   0   0 159   0   1 | 7 = 7
   0   0   0   0   0   0   0   0 161   0 | 8 = 8
   0   0   0   0   0   0   0   4   0 133 | 9 = 9

Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times
==================================================================

可以看出效果是非常不错的,真是简单的数据集捏。

你可能感兴趣的:(java,人工智能)