ND4J/定制操作

package org.nd4j.examples;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.factory.Nd4j;

/**
 *
 * 定制操作是那些在C++ (libnd4j)中定义的,它们还没有被映射为具有Java便利方法,如1.0.0-beta版本。
 * 最终,所有的操作都将被映射成具有Java便利性的方法。同时,可以使用下面显示的“DynamicCustomOp”方法访问它们
 *
 */
public class CustomOpsExamples {


    public static void main(String[] args){

        //第一个示例:反转操作。此操作沿指定维度反转值。
        //c++代码:https://github.com/deeplearning4j/libnd4j/blob/master/include/ops/declarable/generic/transforms/reverse.cpp#L15
        
        //生成线性间隔向量,如下生成1-50,步长大小为50的标量 
        INDArray input = Nd4j.linspace(1, 50, 50)
                        //变形为5行10列的二维数组
                        .reshape(5,10);
        
        //按输入形状创建一个输出数组
        INDArray output = Nd4j.create(input.shape());
        
        //反转操作
        CustomOp op = DynamicCustomOp.builder("reverse")
            .addInputs(input)
            .addOutputs(output)
            //沿维度0(行)进行反转
            .addIntegerArguments(0)
            .build();
        
        //执行操作
        Nd4j.getExecutioner().exec(op);

        System.out.println("反转前");
        System.out.println(input);
        System.out.println();
        System.out.println("反转后");
        System.out.println(output);

       

        
        //另一个例子:网格
        //c++ 代码: https://github.com/deeplearning4j/libnd4j/blob/master/include/ops/declarable/generic/broadcastable/meshgrid.cpp
        //创建[[         0,    0.3333,    0.6667,    1.0000]],即从0开始到1,生成4个数字,每个数字值=((1-0)/(4-1))n   n=0,1,2,3
        INDArray input1 = Nd4j.linspace(0, 1, 4);
        //创建[[         0,    0.2500,    0.5000,    0.7500,    1.0000]],即从0开始到1,生成5个数字,每个数字值=((1-0)/(5-1))n  n=0,1,2,3,4
        INDArray input2 = Nd4j.linspace(0, 1, 5);
        
        //创建5行4列的二维数组
        INDArray output1 = Nd4j.create(5,4);
        //创建5行4列的二维数组
        INDArray output2 = Nd4j.create(5,4);

        //meshgrid详情参考https://blog.csdn.net/u013346007/article/details/54581253
        op = DynamicCustomOp.builder("meshgrid")
            .addInputs(input1, input2)
            .addOutputs(output1,output2)
            .build();
        
        Nd4j.getExecutioner().exec(op);
        
        System.out.println("meshgrid前");
        System.out.println(input1 + "\n\n" + input2);
        
        System.out.println("meshgrid后");
        //output1=input1列数*input2列数=4*5的矩阵,即用input1垂直堆叠n=input2列数 次
        //output2=input2转置后水平堆叠n=input1列数 次
        System.out.println(output1 + "\n\n"+output2);

    }

}

运行结果

反转前
[[    1.0000,    2.0000,    3.0000,    4.0000,    5.0000,    6.0000,    7.0000,    8.0000,    9.0000,   10.0000], 
 [   11.0000,   12.0000,   13.0000,   14.0000,   15.0000,   16.0000,   17.0000,   18.0000,   19.0000,   20.0000], 
 [   21.0000,   22.0000,   23.0000,   24.0000,   25.0000,   26.0000,   27.0000,   28.0000,   29.0000,   30.0000], 
 [   31.0000,   32.0000,   33.0000,   34.0000,   35.0000,   36.0000,   37.0000,   38.0000,   39.0000,   40.0000], 
 [   41.0000,   42.0000,   43.0000,   44.0000,   45.0000,   46.0000,   47.0000,   48.0000,   49.0000,   50.0000]]

反转后
[[   10.0000,    9.0000,    8.0000,    7.0000,    6.0000,    5.0000,    4.0000,    3.0000,    2.0000,    1.0000], 
 [   20.0000,   19.0000,   18.0000,   17.0000,   16.0000,   15.0000,   14.0000,   13.0000,   12.0000,   11.0000], 
 [   30.0000,   29.0000,   28.0000,   27.0000,   26.0000,   25.0000,   24.0000,   23.0000,   22.0000,   21.0000], 
 [   40.0000,   39.0000,   38.0000,   37.0000,   36.0000,   35.0000,   34.0000,   33.0000,   32.0000,   31.0000], 
 [   50.0000,   49.0000,   48.0000,   47.0000,   46.0000,   45.0000,   44.0000,   43.0000,   42.0000,   41.0000]]
meshgrid前
[[         0,    0.3333,    0.6667,    1.0000]]

[[         0,    0.2500,    0.5000,    0.7500,    1.0000]]
meshgrid后
[[         0,    0.3333,    0.6667,    1.0000], 
 [         0,    0.3333,    0.6667,    1.0000], 
 [         0,    0.3333,    0.6667,    1.0000], 
 [         0,    0.3333,    0.6667,    1.0000], 
 [         0,    0.3333,    0.6667,    1.0000]]

[[         0,         0,         0,         0], 
 [    0.2500,    0.2500,    0.2500,    0.2500], 
 [    0.5000,    0.5000,    0.5000,    0.5000], 
 [    0.7500,    0.7500,    0.7500,    0.7500], 
 [    1.0000,    1.0000,    1.0000,    1.0000]]


翻译:风一样的男子

image

你可能感兴趣的:(ND4J/定制操作)