Learning Storm (Part 7)

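This post begins the machine-learning part of the series. Rather than starting from scratch, it uses an existing framework, trident-ml, an open-source project (https://github.com/pmerienne/trident-ml). It implements relatively few algorithms, and its matrix operations are not well developed. As I understand it, Storm's processing model is best suited to online (incremental) training. The library covers classification, regression, clustering, natural language processing and statistics. NLP is too complex to cover here, so I will explain the principles behind the other common algorithms and read through their source code. Let's start with basic statistics: mean, maximum, minimum, variance, and so on.
Below is the basic data format trident-ml works with. An Instance has two members: label and features.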

package com.github.pmerienne.trident.ml.core;

import java.io.Serializable;
import java.util.Arrays;

public class Instance<L> implements Serializable {

    private static final long serialVersionUID = -5378422729499109652L;

    public final L label;
    public final double[] features;

    public Instance(L label, double[] features) {
        this.label = label;       // label (class or target value) of the instance
        this.features = features; // feature vector of the instance
    }

    public Instance(double[] features) {
        this.label = null;
        this.features = features;
    }

    public L getLabel() {
        return label;
    }

    public double[] getFeatures() {
        return features;
    }

    @Override
    public int hashCode() {
        final int prime = 31;
        int result = 1;
        result = prime * result + Arrays.hashCode(features);
        result = prime * result + ((label == null) ? 0 : label.hashCode());
        return result;
    }

    @SuppressWarnings("rawtypes")
    @Override
    public boolean equals(Object obj) {
        if (this == obj)
            return true;
        if (obj == null)
            return false;
        if (getClass() != obj.getClass())
            return false;
        Instance other = (Instance) obj;
        if (!Arrays.equals(features, other.features))
            return false;
        if (label == null) {
            if (other.label != null)
                return false;
        } else if (!label.equals(other.label))
            return false;
        return true;
    }

    @Override
    public String toString() {
        return "Instance [label=" + label + ", features=" + Arrays.toString(features) + "]";
    }

}
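Instance overrides equals and hashCode with Arrays.equals and Arrays.hashCode because Java arrays compare by reference, not by content. The standalone sketch below (the class ArrayEqualityDemo and its nested FeatureKey are hypothetical, not part of trident-ml) shows the difference, which matters whenever instances are compared or used as keys in hash-based collections:

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

class ArrayEqualityDemo {
    // Minimal stand-in for Instance's identity: compares feature arrays by content
    static final class FeatureKey {
        final double[] features;

        FeatureKey(double[] features) {
            this.features = features;
        }

        @Override
        public boolean equals(Object o) {
            return o instanceof FeatureKey && Arrays.equals(features, ((FeatureKey) o).features);
        }

        @Override
        public int hashCode() {
            return Arrays.hashCode(features);
        }
    }

    public static void main(String[] args) {
        double[] a = {1.0, 2.0};
        double[] b = {1.0, 2.0};
        System.out.println(a.equals(b));         // false: arrays use reference equality
        System.out.println(Arrays.equals(a, b)); // true: element-by-element comparison

        Set<FeatureKey> seen = new HashSet<FeatureKey>();
        seen.add(new FeatureKey(a));
        System.out.println(seen.contains(new FeatureKey(b))); // true: content-based equals/hashCode
    }
}
```

Without the Arrays-based overrides, two Instance objects carrying identical features would never be equal.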

Let's start with a demo.

package demo01;

import backtype.storm.Config;
import backtype.storm.StormSubmitter;
import backtype.storm.generated.AlreadyAliveException;
import backtype.storm.generated.InvalidTopologyException;
import backtype.storm.tuple.Fields;

import com.github.pmerienne.trident.ml.preprocessing.InstanceCreator;
import com.github.pmerienne.trident.ml.stats.StreamStatistics;
import com.github.pmerienne.trident.ml.stats.StreamStatisticsQuery;
import com.github.pmerienne.trident.ml.stats.StreamStatisticsUpdater;
import com.github.pmerienne.trident.ml.testing.RandomFeaturesSpout;

import storm.trident.TridentState;
import storm.trident.TridentTopology;
import storm.trident.testing.MemoryMapState;

public class StatisticsDemo01 {

    public static void main(String[] args) throws AlreadyAliveException, InvalidTopologyException {
        TridentTopology topology = new TridentTopology();
        // RandomFeaturesSpout emits "label", "x0", "x1" by default. Only x0 and x1 are selected here,
        // so InstanceCreator must be built with withLabel = false; otherwise x0 would be taken as the label.
        TridentState statsState = topology
                .newStream("randomFeatures", new RandomFeaturesSpout())
                .each(new Fields("x0", "x1"), new InstanceCreator<Object>(false), new Fields("instance"))
                .partitionPersist(new MemoryMapState.Factory(), new Fields("instance"),
                        new StreamStatisticsUpdater("randomFeaturesStream", StreamStatistics.fixed()));
        // DRPC function "queryStats" reads the current statistics from the state
        topology.newDRPCStream("queryStats")
                .stateQuery(statsState, new StreamStatisticsQuery("randomFeaturesStream"), new Fields("streamStats"));
        Config config = new Config();
        config.setDebug(true);
        StormSubmitter.submitTopology("demo02", config, topology.build());
    }
}
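To follow the data flow of the demo without a cluster, here is a plain-Java simulation of the same steps. It is a hypothetical sketch, not trident-ml code: a HashMap stands in for MemoryMapState, the spout and StreamStatisticsUpdater steps are inlined, a double[] plays the role of Instance, and only a running count and mean per feature are kept.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

class LocalPipelineSketch {
    // Hypothetical stand-in for the per-stream statistics: count and mean per feature
    static final class Stats {
        long count;
        final double[] mean;

        Stats(int featureSize) {
            mean = new double[featureSize];
        }

        void update(double[] x) {
            count++;
            for (int j = 0; j < x.length; j++) {
                mean[j] += (x[j] - mean[j]) / count; // incremental mean
            }
        }
    }

    // Runs `batches` batches of `batchSize` tuples through spout -> state update
    static Map<String, Stats> run(int batches, int batchSize, long seed) {
        Map<String, Stats> state = new HashMap<String, Stats>(); // plays the role of MemoryMapState
        Random random = new Random(seed);
        for (int b = 0; b < batches; b++) {
            // "spout": feature j is centered on j with noise scale 3.0, as in RandomFeaturesSpout
            List<double[]> batch = new ArrayList<double[]>();
            for (int i = 0; i < batchSize; i++) {
                batch.add(new double[] {0 + random.nextGaussian() * 3.0, 1 + random.nextGaussian() * 3.0});
            }
            // "updater": get or init, update with the whole batch, put back
            Stats stats = state.get("randomFeaturesStream");
            if (stats == null) {
                stats = new Stats(2);
            }
            for (double[] features : batch) {
                stats.update(features);
            }
            state.put("randomFeaturesStream", stats);
        }
        return state;
    }

    public static void main(String[] args) {
        // "DRPC query": read the current statistics
        Stats s = run(1000, 10, 42L).get("randomFeaturesStream");
        System.out.println("count=" + s.count + " mean(x0)=" + s.mean[0] + " mean(x1)=" + s.mean[1]);
    }
}
```

After enough batches the means approach 0 and 1, matching how the spout generates feature j around j.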

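To get the results, call the DRPC function from a client (see the previous post for how to submit and run the topology):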

package demo01;

import org.apache.thrift7.TException;

import backtype.storm.generated.DRPCExecutionException;
import backtype.storm.utils.DRPCClient;

public class ClientDemo01 {

    public static void main(String[] args) throws TException, DRPCExecutionException {
        // Connect to the DRPC server (3772 is Storm's default DRPC port)
        DRPCClient client = new DRPCClient("localhost", 3772);
        System.out.println("Calling DRPC....................");
        // The argument is ignored by StreamStatisticsQuery; any string works
        String result = client.execute("queryStats", "any words");
        System.out.println(result);
    }

}
The final output looks like this:

[Figure 1: output of the DRPC query]
Now let's look at what the code actually does.
The data source (spout):

/**
 * Copyright 2013-2015 Pierre Merienne
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.github.pmerienne.trident.ml.testing;

import java.util.Map;
import java.util.Random;

import storm.trident.operation.TridentCollector;
import storm.trident.spout.IBatchSpout;
import backtype.storm.task.TopologyContext;
import backtype.storm.tuple.Fields;
import backtype.storm.tuple.Values;

import com.google.common.base.Function;

public class RandomFeaturesSpout implements IBatchSpout {

    private static final long serialVersionUID = -5293861317274377258L;

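    private int maxBatchSize = 10; // number of tuples emitted per batch (default)
    private int featureSize = 2;   // number of features per tuple
    private double variance = 3.0; // scale of the Gaussian noise (used as a standard deviation, despite the name)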

    private boolean withLabel = true;

    private final static Function<double[], Boolean> FEATURES_TO_LABEL = new Function<double[], Boolean>() {
        @Override
        public Boolean apply(double[] input) {
            double sum = 0;
            for (int i = 0; i < input.length; i++) {
                sum += input[i];
            }
            return sum > 0;
        }
    };

    private Random random = new Random();

    public RandomFeaturesSpout() {
    }

    public RandomFeaturesSpout(boolean withLabel) {
        this.withLabel = withLabel;
    }

    public RandomFeaturesSpout(int featureSize, double variance) {
        this.featureSize = featureSize;
        this.variance = variance;
    }

    public RandomFeaturesSpout(boolean withLabel, int featureSize, double variance) {
        this.withLabel = withLabel;
        this.featureSize = featureSize;
        this.variance = variance;
    }

    @SuppressWarnings("rawtypes")
    @Override
    public void open(Map conf, TopologyContext context) {
    }

    @Override
    public void emitBatch(long batchId, TridentCollector collector) {
        for (int i = 0; i < this.maxBatchSize; i++) { // emit maxBatchSize tuples per batch
            Values values = new Values();

            double[] features = new double[this.featureSize]; // feature vector of one instance
            for (int j = 0; j < this.featureSize; j++) {
                features[j] = j + this.random.nextGaussian() * this.variance; // simulated feature j: mean j plus Gaussian noise
            }
            }

            if (this.withLabel) {
                values.add(FEATURES_TO_LABEL.apply(features)); // label = true if the features sum to more than 0
            }

            for (double feature : features) {
                values.add(feature);
            }

            collector.emit(values); // emit [label,] x0, x1, ...
        }
        // Throttle the spout: pause 10 s between batches
        try {
            Thread.sleep(10000);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // restore the interrupt flag
        }
    }

    @Override
    public void ack(long batchId) {
    }

    @Override
    public void close() {
    }

    @SuppressWarnings("rawtypes")
    @Override
    public Map getComponentConfiguration() {
        return null;
    }

    @Override
    public Fields getOutputFields() { // declare the names of the output fields
        String[] fieldNames;

        if (this.withLabel) { // with labels: field 0 is "label", followed by "x0", "x1", ...
            fieldNames = new String[this.featureSize + 1]; // one extra slot for the label
            fieldNames[0] = "label";
            for (int i = 0; i < this.featureSize; i++) {
                fieldNames[i + 1] = "x" + i;
            }
        } else {
            fieldNames = new String[this.featureSize];
            for (int i = 0; i < this.featureSize; i++) {
                fieldNames[i] = "x" + i;
            }
        }

        return new Fields(fieldNames); 
    }

    public int getMaxBatchSize() {
        return maxBatchSize;
    }

    public void setMaxBatchSize(int maxBatchSize) { 
        this.maxBatchSize = maxBatchSize;
    }

    public int getFeatureSize() {
        return featureSize;
    }

    public void setFeatureSize(int featureSize) {
        this.featureSize = featureSize;
    }

    public double getVariance() {
        return variance;
    }

    public void setVariance(double variance) {
        this.variance = variance;
    }

    public boolean isWithLabel() {
        return withLabel;
    }

    public void setWithLabel(boolean withLabel) {
        this.withLabel = withLabel;
    }

    public Random getRandom() {
        return random;
    }

    public void setRandom(Random random) {
        this.random = random;
    }
}
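What matters downstream is the tuple layout. The standalone sketch below re-implements only the field naming of getOutputFields and the FEATURES_TO_LABEL rule (the class and method names are hypothetical), so the layout can be checked without Storm:

```java
import java.util.Arrays;

class SpoutLayoutSketch {
    // Same naming scheme as RandomFeaturesSpout.getOutputFields()
    static String[] outputFieldNames(boolean withLabel, int featureSize) {
        int offset = withLabel ? 1 : 0;
        String[] names = new String[featureSize + offset];
        if (withLabel) {
            names[0] = "label";
        }
        for (int i = 0; i < featureSize; i++) {
            names[i + offset] = "x" + i;
        }
        return names;
    }

    // Same rule as FEATURES_TO_LABEL: true when the features sum to more than 0
    static boolean label(double[] features) {
        double sum = 0;
        for (double f : features) {
            sum += f;
        }
        return sum > 0;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(outputFieldNames(true, 2)));  // [label, x0, x1]
        System.out.println(Arrays.toString(outputFieldNames(false, 2))); // [x0, x1]
        System.out.println(label(new double[] {-1.0, 0.5}));            // false
    }
}
```

So with the no-argument constructor the spout emits label, x0, x1, with the label in the first position.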

The following code converts the raw tuples into Instance objects:

package com.github.pmerienne.trident.ml.preprocessing;

import com.github.pmerienne.trident.ml.core.Instance;

import storm.trident.operation.BaseFunction;
import storm.trident.operation.TridentCollector;
import storm.trident.tuple.TridentTuple;
import backtype.storm.tuple.Values;

public class InstanceCreator<L> extends BaseFunction {

    private static final long serialVersionUID = 3312351524410720639L;

    private boolean withLabel = true;

    public InstanceCreator() {
    }

    public InstanceCreator(boolean withLabel) {
        this.withLabel = withLabel;
    }

    @Override
    public void execute(TridentTuple tuple, TridentCollector collector) {
        Instance<L> instance = this.createInstance(tuple);
        collector.emit(new Values(instance));
    }

    @SuppressWarnings("unchecked")
    protected Instance<L> createInstance(TridentTuple tuple) {
        Instance<L> instance = null;

        if (this.withLabel) {
            L label = (L) tuple.get(0); // field 0 holds the label
            double[] features = new double[tuple.size() - 1]; // the remaining fields are features
            for (int i = 1; i < tuple.size(); i++) {
                features[i - 1] = tuple.getDouble(i);
            }

            instance = new Instance<L>(label, features); // build a labeled Instance
        } else {
            double[] features = new double[tuple.size()];
            for (int i = 0; i < tuple.size(); i++) {
                features[i] = tuple.getDouble(i);
            }

            instance = new Instance<L>(features);
        }

        return instance;
    }
}
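Because createInstance decides purely by position, the withLabel flag must match the fields selected in each(...). The hypothetical sketch below mirrors its logic on a plain Object[] instead of a TridentTuple (it converts with Number.doubleValue, whereas the real code calls tuple.getDouble):

```java
import java.util.Arrays;

class InstanceCreatorSketch {
    // Mirrors createInstance: with a label, field 0 is the label
    static Object label(Object[] tuple, boolean withLabel) {
        return withLabel ? tuple[0] : null;
    }

    // ...and the remaining fields are the features
    static double[] features(Object[] tuple, boolean withLabel) {
        int offset = withLabel ? 1 : 0;
        double[] features = new double[tuple.length - offset];
        for (int i = offset; i < tuple.length; i++) {
            features[i - offset] = ((Number) tuple[i]).doubleValue();
        }
        return features;
    }

    public static void main(String[] args) {
        Object[] x0x1 = {0.7, 1.3}; // only "x0" and "x1" selected
        // withLabel = true: x0 is silently taken as the label, leaving a single feature
        System.out.println(label(x0x1, true) + " " + Arrays.toString(features(x0x1, true)));   // 0.7 [1.3]
        // withLabel = false: both values are features
        System.out.println(label(x0x1, false) + " " + Arrays.toString(features(x0x1, false))); // null [0.7, 1.3]
    }
}
```

A mismatch does not raise an error; it only produces statistics over the wrong number of features.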

Next, how the stream state (the running statistics: mean, variance, etc.) is updated:

package com.github.pmerienne.trident.ml.stats;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import com.github.pmerienne.trident.ml.core.Instance;
import com.github.pmerienne.trident.ml.util.KeysUtil;

import backtype.storm.tuple.Values;

import storm.trident.operation.TridentCollector;
import storm.trident.state.BaseStateUpdater;
import storm.trident.state.map.MapState;
import storm.trident.tuple.TridentTuple;

public class StreamStatisticsUpdater extends BaseStateUpdater<MapState<StreamStatistics>> {

    private static final long serialVersionUID = 1740717206768121351L;

    private String streamName;

    private StreamStatistics initialStatitics;

    public StreamStatisticsUpdater() {
    }

    public StreamStatisticsUpdater(String streamName, StreamStatistics initialStatitics) {
        this.streamName = streamName;             // key under which the statistics are stored
        this.initialStatitics = initialStatitics; // initial (empty) statistics
    }

    /** Updates the stream statistics with a batch of tuples. */
    @Override
    public void updateState(MapState<StreamStatistics> state, List<TridentTuple> tuples, TridentCollector collector) {
        StreamStatistics streamStatistics = this.getStreamStatistics(state); // load the current statistics (or the initial ones)
        List<Instance<?>> instances = this.extractInstances(tuples);         // each call receives a whole batch of tuples

        // Update stream statistics
        this.updateStatistics(streamStatistics, instances); // fold the batch into the statistics

        // Save statistics
        state.multiPut(KeysUtil.toKeys(this.streamName), Arrays.asList(streamStatistics)); // write back under the stream-name key

        // Emit each instance with the updated stats, presumably so downstream operations
        // (via TridentState.newValuesStream()) can use them
        for (Instance<?> instance : instances) {
            collector.emit(new Values(instance, streamStatistics));
        }
    }



    protected List<Instance<?>> extractInstances(List<TridentTuple> tuples) { // take the Instance from field 0 of each tuple
        List<Instance<?>> instances = new ArrayList<Instance<?>>();

        Instance<?> instance;
        for (TridentTuple tuple : tuples) {
            instance = (Instance<?>) tuple.get(0);
            instances.add(instance);
        }

        return instances;
    }

    protected void updateStatistics(StreamStatistics streamStatistics, List<Instance<?>> instances) {  
        for (Instance<?> instance : instances) {
            streamStatistics.update(instance.features); // update the per-feature statistics
        }
    }

    protected StreamStatistics getStreamStatistics(MapState<StreamStatistics> state) {
        List<StreamStatistics> streamStatisticss = state.multiGet(KeysUtil.toKeys(this.streamName)); // multiGet returns one value per key; there is a single key here
        StreamStatistics streamStatistics = null;
        if (streamStatisticss != null && !streamStatisticss.isEmpty()) {
            streamStatistics = streamStatisticss.get(0); // statistics stored for this stream
        }

        // Init it if necessary
        if (streamStatistics == null) {
            streamStatistics = this.initialStatitics; // first batch: start from the initial statistics
        }
        return streamStatistics;
    }

}
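updateState is a read-modify-write cycle against a MapState, whose multiGet and multiPut work on a list of keys (each key itself a List<Object>) and one value per key. Presumably KeysUtil.toKeys(streamName) builds the single key [streamName]; its source is not shown here. The sketch below uses a hypothetical in-memory SimpleMapState with the same shape and a trivial Counter in place of StreamStatistics:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class UpdaterSketch {
    // Hypothetical in-memory state with the same multiGet/multiPut shape as Trident's MapState
    static final class SimpleMapState<T> {
        private final Map<List<Object>, T> store = new HashMap<List<Object>, T>();

        List<T> multiGet(List<List<Object>> keys) {
            List<T> values = new ArrayList<T>();
            for (List<Object> key : keys) {
                values.add(store.get(key)); // null when nothing is stored yet
            }
            return values;
        }

        void multiPut(List<List<Object>> keys, List<T> values) {
            for (int i = 0; i < keys.size(); i++) {
                store.put(keys.get(i), values.get(i));
            }
        }
    }

    // Hypothetical "statistics": just the number of instances seen
    static final class Counter {
        long count;
    }

    // Same cycle as updateState: get (or init), update with the batch, put back
    static Counter updateState(SimpleMapState<Counter> state, String streamName, List<double[]> batch) {
        List<Object> key = new ArrayList<Object>();
        key.add(streamName);
        List<List<Object>> keys = new ArrayList<List<Object>>();
        keys.add(key); // a single key: [streamName]

        Counter stats = state.multiGet(keys).get(0);
        if (stats == null) {
            stats = new Counter(); // first batch: start from the initial statistics
        }
        stats.count += batch.size();
        List<Counter> values = new ArrayList<Counter>();
        values.add(stats);
        state.multiPut(keys, values);
        return stats;
    }

    public static void main(String[] args) {
        SimpleMapState<Counter> state = new SimpleMapState<Counter>();
        List<double[]> batch = Arrays.asList(new double[] {1, 2}, new double[] {3, 4});
        updateState(state, "randomFeaturesStream", batch);
        System.out.println(updateState(state, "randomFeaturesStream", batch).count); // 4
    }
}
```

Statistics accumulate across batches because each batch starts from whatever the previous one stored.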
StreamStatistics holds one statistics object per feature and creates them on demand:
package com.github.pmerienne.trident.ml.stats;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;

public class StreamStatistics implements Serializable {

    /** Statistics for a whole stream: one StreamFeatureStatistics per feature. */

    private static final long serialVersionUID = -3873210308112567893L;

    private Type type = Type.FIXED; // default statistics type
    private Long adativeMaxSize = 1000L;

    private List<StreamFeatureStatistics> featuresStatistics = new ArrayList<StreamFeatureStatistics>();

    public StreamStatistics() {
    }

    public StreamStatistics(Type type) {
        this.type = type;
    }

    public StreamStatistics(Type type, Long adativeMaxSize) {
        this.type = type;
        this.adativeMaxSize = adativeMaxSize;
    }

    public void update(double[] features) {
        StreamFeatureStatistics featureStatistics;
        for (int i = 0; i < features.length; i++) {
            featureStatistics = this.getStreamStatistics(i); // statistics for feature i, created if needed
            featureStatistics.update(features[i]);           // update them with this feature value
        }
    }

    private StreamFeatureStatistics getStreamStatistics(int index) {
        if (this.featuresStatistics.size() < index + 1) { // first time this feature is seen: create its statistics
            StreamFeatureStatistics featureStatistics = this.createFeatureStatistics();
            this.featuresStatistics.add(featureStatistics);
        }
        return this.featuresStatistics.get(index);
    }

    private StreamFeatureStatistics createFeatureStatistics() {
        StreamFeatureStatistics featureStatistics = null;
        switch (this.type) {
        case FIXED:
            featureStatistics = new FixedStreamFeatureStatistics();
            break;
        case ADAPTIVE:
            featureStatistics = new AdaptiveStreamFeatureStatistics(this.adativeMaxSize);
            break;
        default:
            break;
        }
        return featureStatistics;
    }

    public static StreamStatistics fixed() {
        return new StreamStatistics();
    }

    public static StreamStatistics adaptive(Long maxSize) {
        return new StreamStatistics(Type.ADAPTIVE, maxSize);
    }

    public List<StreamFeatureStatistics> getFeaturesStatistics() {
        return featuresStatistics;
    }

    @Override
    public String toString() {
        return "StreamStatistics [type=" + type + ", featuresStatistics=" + featuresStatistics + "]";
    }

    public static enum Type {
        FIXED, ADAPTIVE;
    }
}
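The source of AdaptiveStreamFeatureStatistics is not shown in this post. Assuming the ADAPTIVE type means statistics over only the most recent adativeMaxSize values (my reading of the name and parameter, not verified against the library), a sliding-window version could look like the hypothetical sketch below. It recomputes on demand, which costs O(window) per query but stays simple:

```java
import java.util.ArrayDeque;
import java.util.Deque;

class SlidingWindowStats {
    private final long maxSize;
    private final Deque<Double> window = new ArrayDeque<Double>();

    SlidingWindowStats(long maxSize) {
        this.maxSize = maxSize;
    }

    void update(double feature) {
        window.addLast(feature);
        if (window.size() > maxSize) {
            window.removeFirst(); // forget the oldest value
        }
    }

    long getCount() {
        return window.size();
    }

    double getMean() {
        double sum = 0;
        for (double v : window) {
            sum += v;
        }
        return sum / window.size();
    }

    // Sample variance over the window (n - 1 denominator, as in FixedStreamFeatureStatistics)
    double getVariance() {
        double mean = getMean();
        double m2 = 0;
        for (double v : window) {
            m2 += (v - mean) * (v - mean);
        }
        return m2 / (window.size() - 1);
    }

    public static void main(String[] args) {
        SlidingWindowStats stats = new SlidingWindowStats(3);
        for (double v : new double[] {100, 1, 2, 3}) {
            stats.update(v);
        }
        // 100 has left the window; only 1, 2, 3 remain
        System.out.println(stats.getCount() + " " + stats.getMean() + " " + stats.getVariance()); // 3 2.0 1.0
    }
}
```

Unlike the FIXED type, old values stop influencing the result, which is useful when the stream's distribution drifts.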
With the FIXED type, each feature is tracked by FixedStreamFeatureStatistics, which maintains the count, mean and sum of squared deviations incrementally (Welford's online algorithm):
package com.github.pmerienne.trident.ml.stats;

import java.io.Serializable;

public class FixedStreamFeatureStatistics implements StreamFeatureStatistics, Serializable {

    private static final long serialVersionUID = -7406184811401750690L;

    private long count = 0L;
    private double mean = 0.0;
    private double m2 = 0.0;

    public void update(double feature) {
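        this.count = this.count + 1;                       // number of values seen so far
        double delta = feature - this.mean;                // deviation from the old mean
        this.mean = this.mean + delta / this.count;        // incremental mean update
        this.m2 = this.m2 + delta * (feature - this.mean); // running sum of squared deviations (for variance and std dev)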
    }

    @Override
    public Long getCount() {
        return count;
    }

    @Override
    public Double getMean() {
        return mean;
    }

    @Override
    public Double getVariance() {
        return m2 / (count - 1); // sample variance (divides by n - 1); NaN while only one value has been seen
    }

    @Override
    public Double getStdDev() {
        return Math.sqrt(this.getVariance()); // standard deviation
    }

    @Override
    public String toString() {
        return "SimpleStreamFeatureStatistics [m2=" + m2 + ", count=" + count + ", mean=" + mean + ", variance=" + getVariance() + ", stdDev=" + getStdDev()
                + "]";
    }

}
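The update method is Welford's online algorithm: with n the count and delta = x - mean(n-1), it sets mean(n) = mean(n-1) + delta / n and M2(n) = M2(n-1) + delta * (x - mean(n)), and the sample variance is M2(n) / (n - 1). It needs O(1) memory per feature and avoids the cancellation of the naive sum-of-squares formula, which suits an unbounded stream. The standalone sketch below (class name hypothetical) repeats the same update and checks it against a direct two-pass computation:

```java
class WelfordCheck {
    long count = 0L;
    double mean = 0.0;
    double m2 = 0.0;

    // Same update as FixedStreamFeatureStatistics.update
    void update(double x) {
        count++;
        double delta = x - mean;
        mean += delta / count;
        m2 += delta * (x - mean);
    }

    double variance() {
        return m2 / (count - 1);
    }

    // Reference: two-pass sample variance over data held in memory
    static double twoPassVariance(double[] xs) {
        double sum = 0;
        for (double x : xs) {
            sum += x;
        }
        double mean = sum / xs.length;
        double ss = 0;
        for (double x : xs) {
            ss += (x - mean) * (x - mean);
        }
        return ss / (xs.length - 1);
    }

    public static void main(String[] args) {
        double[] data = {2, 4, 4, 4, 5, 5, 7, 9};
        WelfordCheck w = new WelfordCheck();
        for (double x : data) {
            w.update(x);
        }
        // Both variances should be 32 / 7 (about 4.571), and the mean 5
        System.out.println(w.mean + " " + w.variance() + " " + twoPassVariance(data));
    }
}
```

The two-pass version needs all the data at once; Welford's version gives the same result while seeing each value only once.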

Finally, the query function used when we query the results in real time through DRPC. It is easy to read:

package com.github.pmerienne.trident.ml.stats;

import java.util.ArrayList;
import java.util.List;

import com.github.pmerienne.trident.ml.util.KeysUtil;

import storm.trident.operation.TridentCollector;
import storm.trident.state.BaseQueryFunction;
import storm.trident.state.map.MapState;
import storm.trident.tuple.TridentTuple;
import backtype.storm.tuple.Values;

public class StreamStatisticsQuery extends BaseQueryFunction<MapState<StreamStatistics>, StreamStatistics> {

    private static final long serialVersionUID = -8853291509350751320L;

    private String streamName;

    public StreamStatisticsQuery(String streamName) {
        this.streamName = streamName;
    }

    @Override
    public List<StreamStatistics> batchRetrieve(MapState<StreamStatistics> state, List<TridentTuple> args) {
        // Read the statistics stored under the stream name (a single key)
        List<StreamStatistics> stored = state.multiGet(KeysUtil.toKeys(this.streamName));
        StreamStatistics statistics = stored.isEmpty() ? null : stored.get(0);
        // Trident expects one result per query tuple; repeat the value for every DRPC request in the batch
        List<StreamStatistics> results = new ArrayList<StreamStatistics>(args.size());
        for (int i = 0; i < args.size(); i++) {
            results.add(statistics);
        }
        return results;
    }

    public void execute(TridentTuple tuple, StreamStatistics result, TridentCollector collector) {
        collector.emit(new Values(result));
    }
}
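Trident's state query pairs the list returned by batchRetrieve with the incoming tuples by position, so it must contain exactly one result per element of args; several DRPC requests can arrive in the same batch. A plain-Java sketch of that contract (names hypothetical; a Map stands in for the state and a String for StreamStatistics):

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

class QueryContractSketch {
    // One stored value per stream name, like the MapState used by the updater
    static List<String> batchRetrieve(Map<String, String> state, String streamName, List<String> args) {
        String stats = state.get(streamName); // null before the first update
        List<String> results = new ArrayList<String>(args.size());
        for (int i = 0; i < args.size(); i++) {
            results.add(stats); // the argument is ignored; every request gets the same statistics
        }
        return results;
    }

    public static void main(String[] args) {
        Map<String, String> state = new HashMap<String, String>();
        state.put("randomFeaturesStream", "stats for randomFeaturesStream");
        List<String> requests = Arrays.asList("any words", "other words", "x");
        System.out.println(batchRetrieve(state, "randomFeaturesStream", requests).size()); // 3
    }
}
```

execute then emits one result per request, which is what the DRPC client receives.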

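The code is fairly easy to read, but documentation on this part is very scarce, and I am not entirely sure of every implementation detail. Still, by following this framework it is fairly easy to write your own components. The next post will focus on the machine-learning part.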
