DJL-Java开发者动手学深度学习之归一化处理及源代码

在深度学习训练中,通过会对数据进行归一化处理。通常讲,归一化有两点好处:
1、使不同量纲的特征处于同一数值量级,减少方差大的特征的影响,使模型更准确。
2、加快学习算法的收敛速度。

MinMax归一化

将数据缩放到0和1之间,公式如下:
Y = X i − m i n ( X i ) m a x ( X i ) − m i n ( x i ) Y = \frac{X_i - min(X_i)}{max(X_i) - min(x_i)} Y=max(Xi)min(xi)Ximin(Xi)

标准归一化

将数据所防伪均值是0,方差为1的状态,公式如下:
Y = X i − μ δ Y = \frac{X_i - \mu}{\delta} Y=δXiμ
其缩放结果为:

均值归一化

将数据缩放到-1和1之间,公式如下:
Y = X i − m e a n ( X ) m a x ( X ) − m i n ( X ) Y = \frac{X_i - mean(X)}{max(X) - min(X)} Y=max(X)min(X)Ximean(X)

MinMax源代码解析

归一化的公式相对比较简单,只要照着公式去实现即可。
如简单的Python代码实现如下:

# 计算train数据集的最大值,最小值,平均值
maximums, minimums  = training_data.max(axis=0), training_data.min(axis=0)

# 对数据进行归一化处理

for i in range(feature_num):
data[:, i] = (data[:, i] - minimums[i]) / (maximums[i] - minimums[i])

Java的实现相比Python,代码多一点,不过实现内容一样,下面为DJL框架里封装的源代码,详细如下:

public class MinMaxScaler implements AutoCloseable {

    private NDArray fittedMin;
    private NDArray fittedMax;
    private NDArray fittedRange;
    private float minRange;
    private float maxRange = 1f;
    private boolean detached;

    public MinMaxScaler fit(NDArray data, int[] axises) {
        fittedMin = data.min(axises);
        fittedMax = data.max(axises);
        fittedRange = fittedMax.sub(fittedMin);
        if (detached) {
            detach();
        }
        return this;
    }

    public MinMaxScaler fit(NDArray data) {
        fit(data, new int[] {0});
        return this;
    }

    public NDArray transform(NDArray data) {
        if (fittedRange == null) {
            fit(data, new int[] {0});
        }
        NDArray std = data.sub(fittedMin).divi(fittedRange);
        return scale(std);
    }

    public NDArray transformi(NDArray data) {
        if (fittedRange == null) {
            fit(data, new int[] {0});
        }
        NDArray std = data.subi(fittedMin).divi(fittedRange);
        return scale(std);
    }

    private NDArray scale(NDArray std) {
        // we don't have to scale by custom range when range is default 0..1
        if (maxRange != 1f || minRange != 0f) {
            return std.muli(maxRange - minRange).addi(minRange);
        }
        return std;
    }

    private NDArray inverseScale(NDArray std) {
        // we don't have to scale by custom range when range is default 0..1
        if (maxRange != 1f || minRange != 0f) {
            return std.sub(minRange).divi(maxRange - minRange);
        }
        return std.duplicate();
    }

    private NDArray inverseScalei(NDArray std) {
        // we don't have to scale by custom range when range is default 0..1
        if (maxRange != 1f || minRange != 0f) {
            return std.subi(minRange).divi(maxRange - minRange);
        }
        return std;
    }

    public NDArray inverseTransform(NDArray data) {
        throwsIllegalStateWhenNotFitted();
        NDArray result = inverseScale(data);
        return result.muli(fittedRange).addi(fittedMin);
    }

    public NDArray inverseTransformi(NDArray data) {
        throwsIllegalStateWhenNotFitted();
        NDArray result = inverseScalei(data);
        return result.muli(fittedRange).addi(fittedMin);
    }

    private void throwsIllegalStateWhenNotFitted() {
        if (fittedRange == null) {
            throw new IllegalStateException("Min Max Scaler is not fitted");
        }
    }

    public MinMaxScaler detach() {
        detached = true;
        if (fittedMin != null) {
            fittedMin.detach();
        }
        if (fittedMax != null) {
            fittedMax.detach();
        }
        if (fittedRange != null) {
            fittedRange.detach();
        }
        return this;
    }

    public MinMaxScaler optRange(float minRange, float maxRange) {
        this.minRange = minRange;
        this.maxRange = maxRange;
        return this;
    }

    public NDArray getMin() {
        throwsIllegalStateWhenNotFitted();
        return fittedMin;
    }

    public NDArray getMax() {
        throwsIllegalStateWhenNotFitted();
        return fittedMax;
    }

    @Override
    public void close() {
        if (fittedMin != null) {
            fittedMin.close();
        }
        if (fittedMax != null) {
            fittedMax.close();
        }
        if (fittedRange != null) {
            fittedRange.close();
        }
    }
}

关注公众号【d2lcoder】,解锁更多深度学习内容。

你可能感兴趣的:(d2lcoder,Java开发者动手学习深度学习,DJL,深度学习,java,人工智能)