了解了一些决策树的构建算法后,现在学习下随机森林。还是先上一些基本概念:
随机森林是一种比较新的机器学习模型。顾名思义,是用随机的方式建立一个森林,森林里面有很多的决策树组成,随机森林的每一棵决策树之间是没有关联的。在得到森林之后,当有一个新的输入样本进入的时候,就让森林中的每一棵决策树分别进行一下判断,看看这个样本应该属于哪一类(对于分类算法),然后看看哪一类被选择最多,就预测这个样本为那一类,即选举投票。
优点:
a. 在数据集上表现良好,两个随机性的引入,使得随机森林不容易陷入过拟合
b. 在当前的很多数据集上,相对其他算法有着很大的优势,两个随机性的引入,使得随机森林具有很好的抗噪声能力
c. 它能够处理很高维度(feature很多)的数据,并且不用做特征选择,对数据集的适应能力强:既能处理离散型数据,也能处理连续型数据,数据集无需规范化
d. 可生成一个Proximities=(pij)矩阵,用于度量样本之间的相似性: pij=aij/N, aij表示样本i和j出现在随机森林中同一个叶子结点的次数,N随机森林中树的颗数
e. 在创建随机森林的时候,对generlization error使用的是无偏估计
f. 训练速度快,可以得到变量重要性排序(两种:基于OOB误分率的增加量和基于分裂时的GINI下降量
g. 在训练过程中,能够检测到feature间的互相影响
h. 容易做成并行化方法
i. 实现比较简单
?
说白了,随机森林就是由许多个决策树构成,决策树使用什么算法取决于你。每个决策树构建需要的数据集是总数据集的随机抽取。同时每个抽取出来的数据集也不一定是包含所有特征属性,其含有的特征属性也是随机从总特征属性中随机抽取。随机森林等到所有决策树构建完成后,对样本数据集进行测试分类。最终的结果可以通过简单的投票选择获得,也可以通过复杂的权重计算获得等等。
?
下面是随机森林Java的简单实现
public class ForestBuilder extends BuilderAbstractImpl { /** 决策树数量*/ private int treeNum = 0; /** 随机属性数量*/ private int attributeNum = 0; /** 构建决策树Builder*/ private Builder builder = null; public ForestBuilder(int treeNum, Builder builder, int attributeNum) { this.treeNum = treeNum; this.builder = builder; this.attributeNum = attributeNum; } @Override public Object build(Data data) { ExecutorService pools = Executors.newFixedThreadPool( Runtime.getRuntime().availableProcessors()); List<Future<TreeNode>> futures = new ArrayList<Future<TreeNode>>(); for (int i = 0; i < treeNum; i++) { //线程里面去构建决策树 DecisionCallable callable = new DecisionCallable(data, builder, attributeNum); futures.add(pools.submit(callable)); } System.out.println("futures size: " + futures.size()); //等待线程创建完决策树 List<TreeNode> results = new ArrayList<TreeNode>(); handleFuture(futures, results); int futureLen = futures.size(); int resultsLen = results.size(); while (resultsLen < futureLen) { handleFuture(futures, results); resultsLen = results.size(); } pools.shutdown(); return results; } private void handleFuture(List<Future<TreeNode>> futures, List<TreeNode> results) { Iterator<Future<TreeNode>> iterator = futures.iterator(); while (iterator.hasNext()) { Future<TreeNode> future = iterator.next(); if (future.isDone()) { try { results.add(future.get()); iterator.remove(); } catch (Exception e) { e.printStackTrace(); } } } }}class DecisionCallable implements Callable<TreeNode> { private Data data = null; private int attributeNum = 0; private Builder builder = null; public DecisionCallable(Data data, Builder builder, int attributeNum) { this.data = data; this.builder = builder; this.attributeNum = attributeNum; } @Override public TreeNode call() throws Exception { Data randomData = DataLoader.loadRandom(data, attributeNum); Object object = builder.build(randomData); return null != object ? (TreeNode) object : null; } }
?
public class ForestNode extends Node { private static final long serialVersionUID = 1L; private List<TreeNode> treeNodes = null; public ForestNode(List<TreeNode> treeNodes) { this.treeNodes = treeNodes; } @Override public Object classify(Data data) { List<Object[]> results = new ArrayList<Object[]>(); for (TreeNode treeNode : treeNodes) { Object result = treeNode.classify(data); if (null != result) { results.add((Object[]) treeNode.classify(data)); } } return DataHandler.vote(results); } @Override public Object classify(Instance... instances) { List<Object[]> results = new ArrayList<Object[]>(); for (TreeNode treeNode : treeNodes) { Object result = treeNode.classify(instances); if (null != result) { results.add((Object[]) treeNode.classify(instances)); } } //投票选择 return DataHandler.vote(results); } }
?
?