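With this post we formally start on machine learning, beginning with an existing framework, trident-ml. It is an open-source project (https://github.com/pmerienne/trident-ml). It does not implement many machine learning algorithms, and its matrix operations are not well developed either. My understanding is that Storm's style of processing suits online (incremental) training. The package implements classification, regression, clustering, natural language processing, statistics and more. Natural language processing is far too complex, so I won't cover it. Instead I'll explain how the other common algorithms work and read through their source code. Let's start with common statistics problems such as the mean, maximum, minimum and variance (the statistics code shown below tracks the count, mean, variance and standard deviation of each feature).
Below is the basic data format trident-ml works with: Instance, which has two members, label and features.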
package com.github.pmerienne.trident.ml.core;
import java.io.Serializable;
import java.util.Arrays;
public class Instance<L> implements Serializable {
private static final long serialVersionUID = -5378422729499109652L;
public final L label;
public final double[] features;
public Instance(L label, double[] features) {
this.label = label; // the label (null for unlabeled instances)
this.features = features; // the instance's feature values
}
public Instance(double[] features) {
this.label = null;
this.features = features;
}
public L getLabel() {
return label;
}
public double[] getFeatures() {
return features;
}
@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + Arrays.hashCode(features);
result = prime * result + ((label == null) ? 0 : label.hashCode());
return result;
}
@SuppressWarnings("rawtypes")
@Override
public boolean equals(Object obj) {
if (this == obj)
return true;
if (obj == null)
return false;
if (getClass() != obj.getClass())
return false;
Instance other = (Instance) obj;
if (!Arrays.equals(features, other.features))
return false;
if (label == null) {
if (other.label != null)
return false;
} else if (!label.equals(other.label))
return false;
return true;
}
@Override
public String toString() {
return "Instance [label=" + label + ", features=" + Arrays.toString(features) + "]";
}
}
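Instance overrides equals and hashCode with Arrays.equals and Arrays.hashCode because Java arrays only compare by identity. The stdlib-only sketch below illustrates the difference; FeatureKey is a hypothetical stand-in that mirrors how Instance handles its feature array (without the label), not a trident-ml class.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

// Hypothetical stand-in mirroring Instance's content-based equality (label omitted).
final class FeatureKey {
    final double[] features;

    FeatureKey(double[] features) {
        this.features = features;
    }

    @Override
    public boolean equals(Object o) {
        // compare array contents, like Instance.equals
        return o instanceof FeatureKey && Arrays.equals(features, ((FeatureKey) o).features);
    }

    @Override
    public int hashCode() {
        // hash array contents, like Instance.hashCode
        return Arrays.hashCode(features);
    }
}

final class ArrayEqualityDemo {
    // Plain double[].equals is Object.equals: true only for the same array object.
    static boolean rawArraysEqual(double[] a, double[] b) {
        return a.equals(b);
    }

    // Number of distinct feature vectors when keys compare by content.
    static int distinctKeys(double[]... vectors) {
        Set<FeatureKey> set = new HashSet<>();
        for (double[] v : vectors) {
            set.add(new FeatureKey(v));
        }
        return set.size();
    }

    public static void main(String[] args) {
        double[] a = {1.0, 2.0};
        double[] b = {1.0, 2.0};
        System.out.println(rawArraysEqual(a, b));                   // false: different objects
        System.out.println(Arrays.equals(a, b));                    // true: same contents
        System.out.println(distinctKeys(a, b, new double[] {3.0})); // 2
    }
}
```

This is why Instances with equal features and labels can be used as keys in hash-based collections, which a bare double[] cannot.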
Let's jump straight in with a demo.
package demo01;
import backtype.storm.Config;
import backtype.storm.StormSubmitter;
import backtype.storm.generated.AlreadyAliveException;
import backtype.storm.generated.InvalidTopologyException;
import backtype.storm.tuple.Fields;
import com.github.pmerienne.trident.ml.preprocessing.InstanceCreator;
import com.github.pmerienne.trident.ml.stats.StreamStatistics;
import com.github.pmerienne.trident.ml.stats.StreamStatisticsQuery;
import com.github.pmerienne.trident.ml.stats.StreamStatisticsUpdater;
import com.github.pmerienne.trident.ml.testing.RandomFeaturesSpout;
import storm.trident.TridentState;
import storm.trident.TridentTopology;
import storm.trident.testing.MemoryMapState;
public class StatisticsDemo01 {
public static void main(String[] args) throws AlreadyAliveException, InvalidTopologyException {
// Build a Trident topology that maintains running statistics for the random feature stream
TridentTopology topology = new TridentTopology();
TridentState statsState = topology
.newStream("randomFeatures", new RandomFeaturesSpout())
// The spout emits "label", "x0", "x1"; only x0 and x1 are selected, so InstanceCreator must not treat the first field as a label
.each(new Fields("x0","x1"), new InstanceCreator(false), new Fields("instance"))
.partitionPersist(new MemoryMapState.Factory(), new Fields("instance"), new StreamStatisticsUpdater("randomFeaturesStream", StreamStatistics.fixed()));
topology.newDRPCStream("queryStats")
.stateQuery(statsState, new StreamStatisticsQuery("randomFeaturesStream"), new Fields("streamStats"));
Config config = new Config();
config.setDebug(true);
StormSubmitter.submitTopology("demo02", config,topology.build());
}
}
After submitting the topology (see the previous post for how to run it), query the statistics with a DRPC client:
package demo01;
import org.apache.thrift7.TException;
import backtype.storm.generated.DRPCExecutionException;
import backtype.storm.utils.DRPCClient;
public class ClientDemo01 {
public static void main(String[] args) throws TException, DRPCExecutionException {
// Connect to the DRPC server (3772 is Storm's default DRPC port)
DRPCClient client = new DRPCClient("localhost", 3772);
System.out.println("Calling DRPC....................");
String result = client.execute("queryStats","any words");
System.out.println(result);
}
}
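The final output is as follows (the output from the original post has not been preserved). Next, let's read the source code, starting with RandomFeaturesSpout, which generates the random feature data: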
/** * Copyright 2013-2015 Pierre Merienne * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */
package com.github.pmerienne.trident.ml.testing;
import java.util.Map;
import java.util.Random;
import storm.trident.operation.TridentCollector;
import storm.trident.spout.IBatchSpout;
import backtype.storm.task.TopologyContext;
import backtype.storm.tuple.Fields;
import backtype.storm.tuple.Values;
import com.google.common.base.Function;
public class RandomFeaturesSpout implements IBatchSpout {
private static final long serialVersionUID = -5293861317274377258L;
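private int maxBatchSize = 10; // default number of tuples per batch
private int featureSize = 2; // number of features per tuple
private double variance = 3.0; // scale of the Gaussian noise (it multiplies nextGaussian(), so it acts as a standard deviation)
private boolean withLabel = true;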
private final static Function<double[], Boolean> FEATURES_TO_LABEL = new Function<double[], Boolean>() {
@Override
public Boolean apply(double[] input) {
double sum = 0;
for (int i = 0; i < input.length; i++) {
sum += input[i];
}
return sum > 0;
}
};
private Random random = new Random();
public RandomFeaturesSpout() {
}
public RandomFeaturesSpout(boolean withLabel) {
this.withLabel = withLabel;
}
public RandomFeaturesSpout(int featureSize, double variance) {
this.featureSize = featureSize;
this.variance = variance;
}
public RandomFeaturesSpout(boolean withLabel, int featureSize, double variance) {
this.withLabel = withLabel;
this.featureSize = featureSize;
this.variance = variance;
}
@SuppressWarnings("rawtypes")
@Override
public void open(Map conf, TopologyContext context) {
}
@Override
public void emitBatch(long batchId, TridentCollector collector) {
for (int i = 0; i < this.maxBatchSize; i++) { // emit maxBatchSize tuples per batch
Values values = new Values();
double[] features = new double[this.featureSize]; // one value per feature
for (int j = 0; j < this.featureSize; j++) {
features[j] = j + this.random.nextGaussian() * this.variance; // simulated feature: centered at j, spread scaled by variance
}
if (this.withLabel) {
values.add(FEATURES_TO_LABEL.apply(features)); // label: true if the features sum to more than 0
}
for (double feature : features) {
values.add(feature);
}
collector.emit(values); // emit label + features
}
// Throttle the spout: wait 10 seconds after each batch
try {
Thread.sleep(10000);
} catch (InterruptedException e) {
Thread.currentThread().interrupt(); // restore the interrupt flag instead of only printing the stack trace
}
}
@Override
public void ack(long batchId) {
}
@Override
public void close() {
}
@SuppressWarnings("rawtypes")
@Override
public Map getComponentConfiguration() {
return null;
}
@Override
public Fields getOutputFields() { // names of the emitted fields
String[] fieldNames;
if (this.withLabel) { // with a label, position 0 is "label", followed by "x0", "x1", ...
fieldNames = new String[this.featureSize + 1]; // one extra slot for the label
fieldNames[0] = "label";
for (int i = 0; i < this.featureSize; i++) {
fieldNames[i + 1] = "x" + i;
}
} else {
fieldNames = new String[this.featureSize];
for (int i = 0; i < this.featureSize; i++) {
fieldNames[i] = "x" + i;
}
}
return new Fields(fieldNames);
}
public int getMaxBatchSize() {
return maxBatchSize;
}
public void setMaxBatchSize(int maxBatchSize) {
this.maxBatchSize = maxBatchSize;
}
public int getFeatureSize() {
return featureSize;
}
public void setFeatureSize(int featureSize) {
this.featureSize = featureSize;
}
public double getVariance() {
return variance;
}
public void setVariance(double variance) {
this.variance = variance;
}
public boolean isWithLabel() {
return withLabel;
}
public void setWithLabel(boolean withLabel) {
this.withLabel = withLabel;
}
public Random getRandom() {
return random;
}
public void setRandom(Random random) {
this.random = random;
}
}
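To see what the spout emits without starting Storm, here is a stdlib-only sketch that reproduces one tuple of RandomFeaturesSpout with the default withLabel = true. RandomFeaturesSketch and its methods are hypothetical names for illustration; the feature formula, the label rule and the field order are taken from the source above.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

// Hypothetical, stdlib-only re-creation of one RandomFeaturesSpout tuple (withLabel = true).
final class RandomFeaturesSketch {

    // Builds [label, x0, x1, ...]: feature j is j + N(0,1) * variance, label is (sum of features > 0).
    static List<Object> nextTuple(Random random, int featureSize, double variance) {
        double[] features = new double[featureSize];
        double sum = 0.0;
        for (int j = 0; j < featureSize; j++) {
            features[j] = j + random.nextGaussian() * variance;
            sum += features[j];
        }
        List<Object> values = new ArrayList<>();
        values.add(sum > 0);          // label at position 0, same rule as FEATURES_TO_LABEL
        for (double f : features) {
            values.add(f);            // features follow the label
        }
        return values;
    }

    // Field names in the same order as getOutputFields(): "label", "x0", "x1", ...
    static List<String> fieldNames(int featureSize) {
        List<String> names = new ArrayList<>();
        names.add("label");
        for (int i = 0; i < featureSize; i++) {
            names.add("x" + i);
        }
        return names;
    }

    public static void main(String[] args) {
        Random random = new Random(42);    // fixed seed so runs are repeatable
        System.out.println(fieldNames(2)); // [label, x0, x1]
        System.out.println(nextTuple(random, 2, 3.0));
    }
}
```

Because feature j is j + N(0,1) * variance, with the defaults (featureSize = 2, variance = 3.0) the statistics for x0 and x1 should drift toward means of about 0 and 1 and standard deviations of about 3 as more batches arrive.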
The following code converts a raw tuple into an Instance:
package com.github.pmerienne.trident.ml.preprocessing;
import com.github.pmerienne.trident.ml.core.Instance;
import storm.trident.operation.BaseFunction;
import storm.trident.operation.TridentCollector;
import storm.trident.tuple.TridentTuple;
import backtype.storm.tuple.Values;
public class InstanceCreator<L> extends BaseFunction {
private static final long serialVersionUID = 3312351524410720639L;
private boolean withLabel = true;
public InstanceCreator() {
}
public InstanceCreator(boolean withLabel) {
this.withLabel = withLabel;
}
@Override
public void execute(TridentTuple tuple, TridentCollector collector) {
Instance<L> instance = this.createInstance(tuple);
collector.emit(new Values(instance));
}
@SuppressWarnings("unchecked")
protected Instance<L> createInstance(TridentTuple tuple) {
Instance<L> instance = null;
if (this.withLabel) {
L label = (L) tuple.get(0); // the first field is the label
double[] features = new double[tuple.size() - 1]; // the remaining fields are the features
for (int i = 1; i < tuple.size(); i++) {
features[i - 1] = tuple.getDouble(i);
}
instance = new Instance<L>(label, features); // build a labeled Instance
} else {
double[] features = new double[tuple.size()];
for (int i = 0; i < tuple.size(); i++) {
features[i] = tuple.getDouble(i);
}
instance = new Instance<L>(features);
}
return instance;
}
}
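The withLabel flag decides whether position 0 of the incoming tuple is treated as the label. The sketch below mimics createInstance on a plain List instead of a TridentTuple (InstanceSplitSketch is a hypothetical name). It shows that if only x0 and x1 are passed in while withLabel is true, x0 is silently consumed as the label and only x1 remains as a feature.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch of InstanceCreator.createInstance's split, using a List in place of a TridentTuple.
final class InstanceSplitSketch {

    // Feature array: when withLabel is true, position 0 is the label and is skipped.
    static double[] features(List<Object> tuple, boolean withLabel) {
        int offset = withLabel ? 1 : 0;
        double[] features = new double[tuple.size() - offset];
        for (int i = offset; i < tuple.size(); i++) {
            features[i - offset] = (Double) tuple.get(i);
        }
        return features;
    }

    // Label: position 0 when withLabel is true, otherwise null (like the Instance(double[]) constructor).
    static Object label(List<Object> tuple, boolean withLabel) {
        return withLabel ? tuple.get(0) : null;
    }

    public static void main(String[] args) {
        // Only x0 and x1 selected, as with .each(new Fields("x0","x1"), ...)
        List<Object> x0x1 = Arrays.<Object>asList(0.5, 1.5);
        System.out.println(Arrays.toString(features(x0x1, true)));  // [1.5]: x0 was consumed as the label
        System.out.println(Arrays.toString(features(x0x1, false))); // [0.5, 1.5]
    }
}
```

So the input fields passed to each(...) and the withLabel flag have to agree: either include "label" in the fields, or pass withLabel = false.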
The code below shows how the stream's state (the running statistics: mean, variance and so on) is updated.
package com.github.pmerienne.trident.ml.stats;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import com.github.pmerienne.trident.ml.core.Instance;
import com.github.pmerienne.trident.ml.util.KeysUtil;
import backtype.storm.tuple.Values;
import storm.trident.operation.TridentCollector;
import storm.trident.state.BaseStateUpdater;
import storm.trident.state.map.MapState;
import storm.trident.tuple.TridentTuple;
public class StreamStatisticsUpdater extends BaseStateUpdater<MapState<StreamStatistics>> {
private static final long serialVersionUID = 1740717206768121351L;
private String streamName;
private StreamStatistics initialStatitics;
public StreamStatisticsUpdater() {
}
public StreamStatisticsUpdater(String streamName, StreamStatistics initialStatitics) {
this.streamName = streamName; // key under which the statistics are stored
this.initialStatitics = initialStatitics; // statistics object used while the state is still empty
}
/** Updates the statistics with one batch of tuples. */
@Override
public void updateState(MapState<StreamStatistics> state, List<TridentTuple> tuples, TridentCollector collector) {
StreamStatistics streamStatistics = this.getStreamStatistics(state); // load the current statistics from the state
List<Instance<?>> instances = this.extractInstances(tuples); // extract the instances; each batch contains several tuples
// Update stream statistics
this.updateStatistics(streamStatistics, instances);
// Save statistics under the stream-name key
state.multiPut(KeysUtil.toKeys(this.streamName), Arrays.asList(streamStatistics));
// Emit each instance with the current statistics, presumably so that downstream operations can use them
for (Instance<?> instance : instances) {
collector.emit(new Values(instance, streamStatistics));
}
}
protected List<Instance<?>> extractInstances(List<TridentTuple> tuples) { // pull the Instance out of each tuple
List<Instance<?>> instances = new ArrayList<Instance<?>>();
Instance<?> instance;
for (TridentTuple tuple : tuples) {
instance = (Instance<?>) tuple.get(0);
instances.add(instance);
}
return instances;
}
protected void updateStatistics(StreamStatistics streamStatistics, List<Instance<?>> instances) {
for (Instance<?> instance : instances) {
streamStatistics.update(instance.features); // fold each instance's features into the statistics
}
}
protected StreamStatistics getStreamStatistics(MapState<StreamStatistics> state) {
List<StreamStatistics> streamStatisticss = state.multiGet(KeysUtil.toKeys(this.streamName)); // multiGet takes a list of keys and returns one value per key; here the only key is built from the stream name
StreamStatistics streamStatistics = null;
if (streamStatisticss != null && !streamStatisticss.isEmpty()) {
streamStatistics = streamStatisticss.get(0); // statistics stored for this stream (null if nothing stored yet)
}
// Init it if necessary
if (streamStatistics == null) {
streamStatistics = this.initialStatitics; // first batch: start from the initial statistics
}
return streamStatistics;
}
}
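updateState follows a read, update, write cycle: load the statistics stored under the stream name, fold in the current batch, and store the result back, so the statistics accumulate across batches. Below is a simplified, stdlib-only sketch of that cycle, assuming a HashMap in place of MapState and a running {count, sum} pair in place of StreamStatistics; StateUpdateSketch is a hypothetical name.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch: HashMap stands in for MapState, {count, sum} stands in for StreamStatistics.
final class StateUpdateSketch {
    private final Map<String, double[]> state = new HashMap<>(); // stream name -> {count, sum}

    void updateState(String streamName, List<Double> batch) {
        double[] stats = state.get(streamName);   // like multiGet
        if (stats == null) {
            stats = new double[] {0.0, 0.0};      // like falling back to initialStatitics
        }
        for (double x : batch) {                  // like updateStatistics
            stats[0] += 1;
            stats[1] += x;
        }
        state.put(streamName, stats);             // like multiPut
    }

    long count(String streamName) {
        double[] stats = state.get(streamName);
        return stats == null ? 0L : (long) stats[0];
    }

    double mean(String streamName) {
        double[] stats = state.get(streamName);
        return stats == null ? Double.NaN : stats[1] / stats[0];
    }

    public static void main(String[] args) {
        StateUpdateSketch s = new StateUpdateSketch();
        s.updateState("randomFeaturesStream", Arrays.asList(1.0, 2.0, 3.0)); // batch 1
        s.updateState("randomFeaturesStream", Arrays.asList(4.0, 5.0));      // batch 2
        System.out.println(s.mean("randomFeaturesStream")); // 3.0: statistics span both batches
    }
}
```

The key point is that no batch is kept around; only the summary is stored and updated, which is what makes this approach suitable for an unbounded stream.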
package com.github.pmerienne.trident.ml.stats;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
public class StreamStatistics implements Serializable {
/** Statistics for a stream: one StreamFeatureStatistics per feature. */
private static final long serialVersionUID = -3873210308112567893L;
private Type type = Type.FIXED; // default statistics type
private Long adativeMaxSize = 1000L;
private List<StreamFeatureStatistics> featuresStatistics = new ArrayList<StreamFeatureStatistics>();
public StreamStatistics() {
}
public StreamStatistics(Type type) {
this.type = type;
}
public StreamStatistics(Type type, Long adativeMaxSize) {
this.type = type;
this.adativeMaxSize = adativeMaxSize;
}
public void update(double[] features) {
StreamFeatureStatistics featureStatistics;
for (int i = 0; i < features.length; i++) {
featureStatistics = this.getStreamStatistics(i); // statistics for feature i, created on first use
featureStatistics.update(features[i]); // update them with this feature's value
}
}
private StreamFeatureStatistics getStreamStatistics(int index) {
if (this.featuresStatistics.size() < index + 1) { // no statistics for this feature yet
StreamFeatureStatistics featureStatistics = this.createFeatureStatistics();
this.featuresStatistics.add(featureStatistics);
}
return this.featuresStatistics.get(index);
}
private StreamFeatureStatistics createFeatureStatistics() {
StreamFeatureStatistics featureStatistics = null;
switch (this.type) {
case FIXED:
featureStatistics = new FixedStreamFeatureStatistics();
break;
case ADAPTIVE:
featureStatistics = new AdaptiveStreamFeatureStatistics(this.adativeMaxSize);
break;
default:
break;
}
return featureStatistics;
}
public static StreamStatistics fixed() {
return new StreamStatistics();
}
public static StreamStatistics adaptive(Long maxSize) {
return new StreamStatistics(Type.ADAPTIVE, maxSize);
}
public List<StreamFeatureStatistics> getFeaturesStatistics() {
return featuresStatistics;
}
@Override
public String toString() {
return "StreamStatistics [type=" + type + ", featuresStatistics=" + featuresStatistics + "]";
}
public static enum Type {
FIXED, ADAPTIVE;
}
}
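StreamStatistics.adaptive(maxSize) creates AdaptiveStreamFeatureStatistics, whose source is not shown in this post. Judging only by the maxSize parameter, it presumably limits the statistics to recent data. The following is a hypothetical sliding-window sketch of that idea, not trident-ml's implementation: WindowedFeatureStatistics keeps the last maxSize values and computes the mean and sample variance over them.

```java
import java.util.ArrayDeque;
import java.util.Deque;

// Hypothetical windowed statistic (NOT trident-ml's AdaptiveStreamFeatureStatistics):
// only the most recent maxSize values are kept.
final class WindowedFeatureStatistics {
    private final long maxSize;
    private final Deque<Double> window = new ArrayDeque<>();

    WindowedFeatureStatistics(long maxSize) {
        this.maxSize = maxSize;
    }

    void update(double feature) {
        window.addLast(feature);
        if (window.size() > maxSize) {
            window.removeFirst(); // forget the oldest value
        }
    }

    long getCount() {
        return window.size();
    }

    double getMean() {
        if (window.isEmpty()) {
            return Double.NaN;
        }
        double sum = 0.0;
        for (double x : window) {
            sum += x;
        }
        return sum / window.size();
    }

    // Sample variance over the window (n - 1 denominator, like FixedStreamFeatureStatistics).
    double getVariance() {
        double mean = getMean();
        double ss = 0.0;
        for (double x : window) {
            ss += (x - mean) * (x - mean);
        }
        return ss / (window.size() - 1);
    }

    public static void main(String[] args) {
        WindowedFeatureStatistics stats = new WindowedFeatureStatistics(3);
        for (double x : new double[] {100, 1, 2, 3}) {
            stats.update(x);
        }
        System.out.println(stats.getCount()); // 3
        System.out.println(stats.getMean());  // 2.0: the old value 100 has been dropped
    }
}
```

The trade-off versus FIXED: fixed statistics use constant memory but summarize the whole history, while a window needs O(maxSize) memory (and this naive version recomputes in O(maxSize) per query) but follows changes in the data's distribution.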
package com.github.pmerienne.trident.ml.stats;
import java.io.Serializable;
public class FixedStreamFeatureStatistics implements StreamFeatureStatistics, Serializable {
private static final long serialVersionUID = -7406184811401750690L;
private long count = 0L;
private double mean = 0.0;
private double m2 = 0.0;
public void update(double feature) {
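this.count = this.count + 1; // number of values seen so far
double delta = feature - this.mean; // distance from the previous mean
this.mean = this.mean + delta / this.count; // incremental (running) mean
this.m2 = this.m2 + delta * (feature - this.mean); // running sum of squared deviations (Welford), used for the variance and std dev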
}
@Override
public Long getCount() {
return count;
}
@Override
public Double getMean() {
return mean;
}
@Override
public Double getVariance() {
return m2 / (count - 1); // sample variance; NaN while count == 1
}
@Override
public Double getStdDev() {
return Math.sqrt(this.getVariance()); // standard deviation
}
@Override
public String toString() {
return "SimpleStreamFeatureStatistics [m2=" + m2 + ", count=" + count + ", mean=" + mean + ", variance=" + getVariance() + ", stdDev=" + getStdDev()
+ "]";
}
}
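FixedStreamFeatureStatistics implements Welford's online algorithm. For the n-th value x_n: mean_n = mean_{n-1} + (x_n - mean_{n-1}) / n and M2_n = M2_{n-1} + (x_n - mean_{n-1}) * (x_n - mean_n), and the sample variance is M2_n / (n - 1). It needs only three numbers per feature and avoids the cancellation problems of the naive "mean of squares minus square of mean" formula. The stdlib-only sketch below (WelfordCheck is a hypothetical name) repeats the same update and compares it with a two-pass computation.

```java
// Hypothetical check: the same update as FixedStreamFeatureStatistics.update vs. a two-pass variance.
final class WelfordCheck {
    long count = 0;
    double mean = 0.0;
    double m2 = 0.0;

    void update(double x) {
        count++;
        double delta = x - mean;   // distance to the OLD mean
        mean += delta / count;     // incremental mean
        m2 += delta * (x - mean);  // uses the OLD (delta) and the NEW mean
    }

    double sampleVariance() {
        return m2 / (count - 1);
    }

    // Reference: compute the mean first, then the sum of squared deviations.
    static double twoPassVariance(double[] xs) {
        double sum = 0.0;
        for (double x : xs) {
            sum += x;
        }
        double mean = sum / xs.length;
        double ss = 0.0;
        for (double x : xs) {
            ss += (x - mean) * (x - mean);
        }
        return ss / (xs.length - 1);
    }

    public static void main(String[] args) {
        double[] xs = {2, 4, 4, 4, 5, 5, 7, 9};
        WelfordCheck w = new WelfordCheck();
        for (double x : xs) {
            w.update(x);
        }
        System.out.println(w.mean);              // close to 5.0
        System.out.println(w.sampleVariance());  // close to 32 / 7
        System.out.println(twoPassVariance(xs)); // 32 / 7
    }
}
```

Note that with a single value m2 is 0 and count - 1 is 0, so getVariance() (and therefore getStdDev()) returns NaN until a second value arrives.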
The code below runs when we query the results in real time via DRPC; it is fairly easy to read:
package com.github.pmerienne.trident.ml.stats;
import java.util.List;
import com.github.pmerienne.trident.ml.util.KeysUtil;
import storm.trident.operation.TridentCollector;
import storm.trident.state.BaseQueryFunction;
import storm.trident.state.map.MapState;
import storm.trident.tuple.TridentTuple;
import backtype.storm.tuple.Values;
public class StreamStatisticsQuery extends BaseQueryFunction<MapState<StreamStatistics>, StreamStatistics> {
private static final long serialVersionUID = -8853291509350751320L;
private String streamName;
public StreamStatisticsQuery(String streamName) {
this.streamName = streamName;
}
@Override
public List<StreamStatistics> batchRetrieve(MapState<StreamStatistics> state, List<TridentTuple> args) {
List<StreamStatistics> statistics = state.multiGet(KeysUtil.toKeys(this.streamName)); // read the statistics stored under the stream name
return statistics;
}
public void execute(TridentTuple tuple, StreamStatistics result, TridentCollector collector) {
collector.emit(new Values(result));
}
}
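The code is fairly easy to read, but documentation on this part is very scarce and I am not entirely clear on every implementation detail myself. Still, by following this framework it is fairly easy to write something of your own. The next post will focus on the machine learning part.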