Deep Learning 4J 学习(三) 利用神经网络学习绘图

1.先放置一张要绘制的图片(本例使用图像处理的经典图片lina)

Deep Learning 4J 学习(三) 利用神经网络学习绘图_第1张图片

 

Deep Learning 4J 学习(三) 利用神经网络学习绘图_第2张图片

 

2.ImageDrawer.java(不断的根据坐标值,调整输出的RGB值)

package com.jiantsing.test;

import javafx.application.Application;
import javafx.application.Platform;
import javafx.scene.Scene;
import javafx.scene.image.*;
import javafx.scene.layout.HBox;
import javafx.scene.paint.Color;
import javafx.stage.Stage;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;

/**
* JavaFX application to show a neural network learning to draw an image.
* Demonstrates how to feed an NN with externally originated data.
*
* This example uses JavaFX, which requires the Oracle JDK. Comment out this example if you use a different JDK.
* OpenJDK and openjfx have been reported to work fine.
*
* TODO: sample does not shut down correctly. Process must be stopped from the IDE.
*
* @author Robert Altena
*/
public class ImageDrawer extends Application {

    private Image originalImage; //The source image displayed on the left.
    private WritableImage composition; // Destination image generated by the NN.
    private MultiLayerNetwork nn; // THE nn.
    private DataSet ds; //Training data generated (only once) from the Original, used to train.
    private INDArray xyOut; //x,y grid to calculate the output image. Needs to be calculated once, then re-used.

    /**
     * Training the NN and updating the current graphical output.
     */
    private void onCalc(){
        nn.fit(ds);//训练神经网络
        drawImage();//再次输出测试
        Platform.runLater(this::onCalc);
    }

    @Override
    public void init(){
//        originalImage = new Image("/DataExamples/Mona_Lisa.png");
        originalImage = new Image("/images/lina.jpg");


        final int w = (int) originalImage.getWidth();
        final int h = (int) originalImage.getHeight();
        composition = new WritableImage(w, h); //Right image.

        ds = generateDataSet(originalImage);
        nn = createNN();

        // The x,y grid to calculate the NN output only needs to be calculated once.
        int numPoints = h * w;
        xyOut = Nd4j.zeros(numPoints, 2);//坐标输入矩阵
        for (int i = 0; i < w; i++) {
            double xp = (double) i / (double) (w - 1);
            for (int j = 0; j < h; j++) {
                int index = i + w * j;//0----w*h-1,用行下标表示所有的点阵
                double yp = (double) j / (double) (h - 1);

                xyOut.put(index, 0, xp); //2 inputs. x and y.,x坐标,double表示便于计算
                xyOut.put(index, 1, yp);//y坐标,double表示便于计算
            }
        }
//        System.out.println(xyOut);
        drawImage();//此时未调用神经网络,未开始训练
    }
    /**
     * Standard JavaFX start: Build the UI, display
     */
    @Override
    public void start(Stage primaryStage) {

        final int w = (int) originalImage.getWidth();
        final int h = (int) originalImage.getHeight();
        final int zoom = 1; // Our images are a tad small, display them enlarged to have something to look at.

        ImageView iv1 = new ImageView(); //Left image
        iv1.setImage(originalImage);
        iv1.setFitHeight( zoom* h);
        iv1.setFitWidth(zoom*w);

        ImageView iv2 = new ImageView();
        iv2.setImage(composition);
        iv2.setFitHeight( zoom* h);
        iv2.setFitWidth(zoom*w);

        HBox root = new HBox(); //build the scene.
        Scene scene = new Scene(root);
        root.getChildren().addAll(iv1, iv2);

        primaryStage.setTitle("Neural Network Drawing Demo.");
        primaryStage.setScene(scene);
        primaryStage.show();

        Platform.setImplicitExit(true);

        //Allow JavaFX do to it's thing, Initialize the Neural network when it feels like it.
        Platform.runLater(this::onCalc);
    }

    public static void main( String[] args )
    {
        launch(args);
    }

    /**
     * Build the Neural network.
     */
    private static MultiLayerNetwork createNN() {
        int seed = 2345;
        int iterations = 25; //<-- Just the one iteration per call to fit.
        double learningRate = 0.1;
        int numInputs = 2;   // x and y.,坐标输入
        int numHiddenNodes = 25;
        int numOutputs = 3 ; //R, G and B value.,颜色输出

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .iterations(iterations)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .learningRate(learningRate)
            .weightInit(WeightInit.XAVIER)
            .updater(Updater.NESTEROVS).momentum(0.9)
            .list()
            .layer(0, new DenseLayer.Builder().nIn(numInputs).nOut(numHiddenNodes)
                .activation(Activation.IDENTITY)
                .build())
            .layer(1, new DenseLayer.Builder().nIn(numHiddenNodes).nOut(numHiddenNodes)
                .activation(Activation.RELU)
                .build())
            .layer(2, new DenseLayer.Builder().nIn(numHiddenNodes).nOut(numHiddenNodes)
                .activation(Activation.RELU)
                .build())
            .layer(3, new DenseLayer.Builder().nIn(numHiddenNodes).nOut(numHiddenNodes)
                .activation(Activation.RELU)
                .build())
            .layer(4, new OutputLayer.Builder(LossFunctions.LossFunction.L2)
                .activation(Activation.IDENTITY)
                .nIn(numHiddenNodes).nOut(numOutputs).build())
            .pretrain(false).backprop(true).build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        return net;
    }

    /**
     * Process a javafx Image to be consumed by DeepLearning4J
     *
     * @param img Javafx image to process
     * @return DeepLearning4J DataSet.
     */
    private static DataSet generateDataSet(Image img) {
        int w = (int) img.getWidth();
        int h = (int) img.getHeight();
        int numPoints = h * w;

        PixelReader reader = img.getPixelReader();

        INDArray xy = Nd4j.zeros(numPoints, 2);//输入
        INDArray out = Nd4j.zeros(numPoints, 3);//答案

        //Simplest implementation first.
        for (int i = 0; i < w; i++) {
            double xp = (double) i / (double) (w - 1);
            for (int j = 0; j < h; j++) {
                Color c = reader.getColor(i, j);
                int index = i + w * j;//0----w*h-1,用行下标表示所有的点阵
                double yp = (double) j / (double) (h - 1);

                xy.put(index, 0, xp); //2 inputs. x and y.
                xy.put(index, 1, yp);

                out.put(index, 0, c.getRed());  //3 outputs. the RGB values.
                out.put(index, 1, c.getGreen());
                out.put(index, 2, c.getBlue());
            }
        }
        return new DataSet(xy, out);
    }

    /**
     * Make the Neural network draw the image.
     */
    private void drawImage() {
        int w = (int) composition.getWidth();
        int h = (int) composition.getHeight();

//        System.out.println(xyOut);
        INDArray out = nn.output(xyOut);//根据坐标矩阵,计算颜色输出矩阵
        PixelWriter writer = composition.getPixelWriter();

        for (int i = 0; i < w; i++) {
            for (int j = 0; j < h; j++) {
                int index = i + w * j;
                double red = capNNOutput(out.getDouble(index, 0));//capNNOutput转化时必须>=0,<=1
                double green = capNNOutput(out.getDouble(index, 1));
                double blue = capNNOutput(out.getDouble(index, 2));

                Color c = new Color(red, green, blue, 1.0);
                writer.setColor(i, j, c);
            }
        }
    }

    /**
     * Make sure the color values are >=0 and <=1
     */
    private static double capNNOutput(double x) {
        double tmp = (x<0.0) ? 0.0 : x;
        return (tmp > 1.0) ? 1.0 : tmp;
    }
}

 

3.效果图(运行起来有点卡,左边原图,右边电脑绘制,运行越久越逼真)

Deep Learning 4J 学习(三) 利用神经网络学习绘图_第3张图片

你可能感兴趣的:(Deep Learning 4J 学习(三) 利用神经网络学习绘图)