Random Forest

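      By repeating this bootstrap sampling (drawing m examples with replacement) T times, we obtain T sampling sets, each containing m training samples. We then train one base learner on each sampling set and combine these base learners. This is the general Bagging procedure. When combining the predicted outputs, Bagging usually uses simple voting for classification and simple averaging for regression. If two classes receive the same number of votes in classification, the simplest approach is to pick one of them at random. Alternatively, the confidence of the learners' votes can be examined to determine the final winner.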
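The Bagging procedure just described can be sketched in a few lines of plain Scala. This is a minimal illustration under stated assumptions, not MLlib code. `BaggingSketch`, `bootstrapSamples`, `vote` and `average` are hypothetical names, and training the base learners is left out. Ties in `vote` are broken uniformly at random, which is the simplest option mentioned above.

```scala
import scala.util.Random

object BaggingSketch {
  /** Draws t bootstrap sampling sets, each of the same size m as the data,
    * by sampling indices uniformly with replacement. */
  def bootstrapSamples[A](data: IndexedSeq[A], t: Int, rng: Random): Seq[IndexedSeq[A]] = {
    val m = data.length
    Seq.fill(t)(IndexedSeq.fill(m)(data(rng.nextInt(m))))
  }

  /** Simple voting for classification; a tie is broken by picking one of the
    * tied labels uniformly at random. */
  def vote(predictions: Seq[Int], rng: Random): Int = {
    val counts: Map[Int, Int] =
      predictions.groupBy(identity).map { case (label, ls) => label -> ls.size }
    val best = counts.values.max
    val tied = counts.collect { case (label, c) if c == best => label }.toIndexedSeq.sorted
    tied(rng.nextInt(tied.size))
  }

  /** Simple averaging for regression. */
  def average(predictions: Seq[Double]): Double = predictions.sum / predictions.size
}
```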

      The Bagging algorithm is described in the figure below.

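      Random forest is an extended variant of Bagging. It builds a Bagging ensemble with decision trees as base learners, and it also introduces random attribute selection into the training of each tree. A traditional decision tree chooses the split attribute by selecting the optimal attribute from the attribute set of the current node (assume there are d attributes). In a random forest, each node of a base tree first draws a random subset of k attributes from its attribute set, and then selects the optimal attribute for the split from this subset.
    The parameter k controls how much randomness is introduced. With k=d, the base tree is built exactly like a traditional decision tree. With k=1, a single attribute is chosen at random for the split. MLlib offers two choices for classification, k=log2(d) and k=sqrt(d), and one for regression, k=d/3 (each rounded up). These choices are covered in detail in the source-code analysis.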
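The sketch below shows the random attribute selection step in isolation. It assumes that attributes are identified by the indices 0 until d and that the caller supplies a split-quality `score` function (for example, information gain). `RandomAttributeSelection`, `sampleAttributes` and `chooseSplitAttribute` are hypothetical names used only for illustration. With k = d the choice matches a traditional decision tree, and with k = 1 it is a random attribute.

```scala
import scala.util.Random

object RandomAttributeSelection {
  /** Draws k distinct attribute indices out of d, uniformly at random. */
  def sampleAttributes(d: Int, k: Int, rng: Random): Vector[Int] =
    rng.shuffle((0 until d).toVector).take(k)

  /** Picks the best attribute within the random subset. `score` is a hypothetical
    * split-quality function (e.g. information gain) supplied by the caller. */
  def chooseSplitAttribute(d: Int, k: Int, score: Int => Double, rng: Random): Int =
    sampleAttributes(d, k, rng).reduce((a, b) => if (score(b) > score(a)) b else a)
}
```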

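      As this shows, random forest makes only a small change to Bagging. In Bagging, the "diversity" of the base learners comes only from sample perturbation (sampling the initial training set). In a random forest, it comes from both sample perturbation and attribute perturbation.
    As a result, the generalization performance of the final ensemble can improve further as the differences between the individual learners increase.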

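      Random forest is easy to implement on a single machine. In a distributed environment, and on Spark in particular, the single-machine iteration scheme must be adapted. Because the data itself is distributed, a poorly designed algorithm generates a large amount of I/O, such as frequent network transfers, and this hurts efficiency.
    Implementing random forest on Spark therefore requires some optimization. Spark's random forest implements three main optimization strategies: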

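    • Split-point sampling statistics, as shown in the figure below. To choose split points for a continuous feature, a single-machine decision tree usually sorts the feature values and takes the points between adjacent values as candidate split points. This works on a single machine. In a distributed environment, however, the same procedure causes a large amount of network transfer, and at PB-scale data the algorithm becomes extremely inefficient.
      To avoid this problem, Spark's random forest samples the data when building the trees, computes statistics over the sample, and derives the split points from those statistics.
      (In the source code, the input is sampled first. The distinct sampled values are then counted and sorted by value, and split points are chosen so that each bin holds roughly the same number of samples.)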
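    • Feature binning, as shown in the figure below. Building a decision tree means repeatedly partitioning the values of the features. A categorical feature with M values has at most 2^(M-1) - 1 splits if its values are unordered, and at most M-1 splits if they are ordered.
      For example, take an age feature with the 3 values old, middle-aged and young. If the values are unordered, there are 2^2 - 1 = 3 splits: old | middle-aged, young; old, middle-aged | young; old, young | middle-aged. If they are ordered as old, middle-aged, young, there are only M - 1 = 2 splits: old | middle-aged, young; old, middle-aged | young.
      For a continuous feature, partitioning means dividing its range. The dividing points are the splits, and the resulting intervals are the bins. In theory a continuous feature has infinitely many candidate splits, and a distributed environment cannot collect all of its values, so split-point sampling statistics are used instead.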
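    • Level-wise training, as shown in the figure below. The single-machine version builds a decision tree by recursion (essentially depth-first). While it builds the tree, it must move data around so that all samples belonging to the same child node end up together.
      This approach cannot run efficiently on distributed data structures. It is often not feasible at all, because the data is too large to gather in one place. The distributed implementation therefore builds tree nodes level by level (essentially breadth-first), so the number of passes over the data equals the maximum depth over all trees.
      Each pass only computes the statistics of every candidate split for every node being trained. After the pass, these statistics decide whether each node is split and, if so, how.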
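The split counts for categorical features can be checked by enumerating the splits directly. The sketch below is illustrative only, and `CategoricalSplits` and its methods are hypothetical. For an unordered feature, each split is encoded as a bitmask over the first M - 1 categories. The last category always stays on the right, so each partition is produced once, which gives 2^(M-1) - 1 splits. This appears to be the same encoding that MLlib's `extractMultiClassCategories(splitIndex + 1, featureArity)` decodes later on. An ordered feature only has the M - 1 prefix splits.

```scala
object CategoricalSplits {
  /** Unordered feature: all 2^(M-1) - 1 ways to split M categories into two
    * non-empty groups. Bit i of `mask` puts category i on the left; the last
    * category never appears in a mask, so it always stays on the right and each
    * partition is generated exactly once. */
  def unorderedSplits[A](categories: IndexedSeq[A]): Seq[(Seq[A], Seq[A])] = {
    val m = categories.length
    (1 until (1 << (m - 1))).map { mask =>
      val (left, right) = categories.indices.partition(i => ((mask >> i) & 1) == 1)
      (left.map(i => categories(i)), right.map(i => categories(i)))
    }
  }

  /** Ordered feature: only the M - 1 splits that keep the order intact. */
  def orderedSplits[A](categories: IndexedSeq[A]): Seq[(Seq[A], Seq[A])] =
    (1 until categories.length).map(i => categories.splitAt(i))
}
```

For the age example, `unorderedSplits(Vector("old", "middle-aged", "young"))` yields the 3 splits listed above, and `orderedSplits` yields the 2 ordered ones.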

      The following example trains a random forest for regression.

    import org.apache.spark.mllib.tree.RandomForest
    import org.apache.spark.mllib.tree.model.RandomForestModel
    import org.apache.spark.mllib.util.MLUtils
    // Load and parse the data file (`sc` is an existing SparkContext, e.g. in spark-shell).
    val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
    // Split the data into training and test sets (30% held out for testing)
    val splits = data.randomSplit(Array(0.7, 0.3))
    val (trainingData, testData) = (splits(0), splits(1))
    // Train a RandomForest model.
    // An empty categoricalFeaturesInfo means all features are treated as continuous.
    val categoricalFeaturesInfo = Map[Int, Int]()
    val numTrees = 3 // Use more in practice.
    val featureSubsetStrategy = "auto" // Let the algorithm choose ("onethird" for regression).
    val impurity = "variance"
    val maxDepth = 4
    val maxBins = 32
    val model = RandomForest.trainRegressor(trainingData, categoricalFeaturesInfo,
      numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)
    // Evaluate model on test instances and compute test error
    val labelsAndPredictions = testData.map { point =>
      val prediction = model.predict(point.features)
      (point.label, prediction)
    }
    val testMSE = labelsAndPredictions.map { case (v, p) => math.pow(v - p, 2) }.mean()
    println("Test Mean Squared Error = " + testMSE)
    println("Learned regression forest model:\n" + model.toDebugString)

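      Training can be roughly divided into two steps: initialization, and iteratively building the random forest. Each step consists of several sub-steps, described below.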

    5.1.1 Initialization

    val retaggedInput = input.retag(classOf[LabeledPoint])
    // Build the decision-tree metadata (number of features, number of bins per feature,
    // which categorical features are unordered, number of features per node, etc.)
    val metadata =
      DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
    // Find the split points (splits) and the bins.
    // For continuous features, split-point sampling statistics simplify the computation.
    // An unordered categorical feature with numCategories values has at most
    // 2^(numCategories - 1) - 1 splits; an ordered one has at most numCategories - 1 splits.
    val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, metadata)
    // Convert to an RDD of TreePoint; afterwards every sample has been assigned to its bins
    // according to the split conditions.
    val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)
    // Sample with replacement only when more than one tree is trained
    val withReplacement = if (numTrees > 1) true else false
    // convertToBaggedRDD gives each tree its own subsample of the data
    val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput,
        strategy.subsamplingRate, numTrees,
        withReplacement, seed).persist(StorageLevel.MEMORY_AND_DISK)
    // Maximum tree depth (MLlib allows at most 30)
    val maxDepth = strategy.maxDepth
    // Maximum memory (in bytes) for aggregating node statistics
    val maxMemoryUsage: Long = strategy.maxMemoryInMB * 1024L * 1024L
    val maxMemoryPerNode = {
      val featureSubset: Option[Array[Int]] = if (metadata.subsamplingFeatures) {
        // Find numFeaturesPerNode largest bins to get an upper bound on memory usage.
        Some(metadata.numBins.zipWithIndex.sortBy(- _._1)
          .take(metadata.numFeaturesPerNode).map(_._2))
      } else {
        None
      }
      // Memory needed to aggregate the statistics of one node (8 bytes per Double)
      RandomForest.aggregateSizeForNode(metadata, featureSubset) * 8L
    }
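`convertToBaggedRDD` does not copy the data once per tree. As far as I can tell from the MLlib source, each point instead carries one weight per tree. The weight is a Poisson(subsamplingRate) draw when sampling with replacement, and a 0/1 Bernoulli draw otherwise. Poisson(1) is a natural choice because the number of times one of m items appears in a bootstrap sample of size m is Binomial(m, 1/m), which is close to Poisson(1) for large m.

The sketch below imitates that idea using only the standard library. `PoissonBagging` and its methods are hypothetical. Knuth's simple Poisson sampler is used here in place of a library implementation.

```scala
import scala.util.Random

object PoissonBagging {
  /** Knuth's method for drawing a Poisson(lambda) variate (fine for small lambda). */
  def poisson(lambda: Double, rng: Random): Int = {
    val limit = math.exp(-lambda)
    var k = 0
    var p = rng.nextDouble()
    while (p > limit) {
      k += 1
      p *= rng.nextDouble()
    }
    k
  }

  /** For each instance, one weight per tree: how many times the instance appears in
    * that tree's sample. With replacement the weight is ~ Poisson(rate); without
    * replacement it is a Bernoulli(rate) 0/1 indicator. */
  def baggedWeights(numInstances: Int, numTrees: Int, rate: Double,
                    withReplacement: Boolean, rng: Random): Array[Array[Int]] =
    Array.fill(numInstances) {
      Array.fill(numTrees) {
        if (withReplacement) poisson(rate, rng)
        else if (rng.nextDouble() < rate) 1 else 0
      }
    }
}
```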

      The first step of initialization is building the decision-tree metadata. The code is as follows.

    def buildMetadata(
        input: RDD[LabeledPoint],
        strategy: Strategy,
        numTrees: Int,
        featureSubsetStrategy: String): DecisionTreeMetadata = {
      // Number of features
      val numFeatures = input.map(_.features.size).take(1).headOption.getOrElse {
        throw new IllegalArgumentException(s"DecisionTree requires size of input RDD > 0, " +
          s"but was given by empty one.")
      }
      val numExamples = input.count()
      val numClasses = strategy.algo match {
        case Classification => strategy.numClasses
        case Regression => 0
      }
      // Maximum possible number of bins: maxBins, capped by the number of examples
      val maxPossibleBins = math.min(strategy.maxBins, numExamples).toInt
      if (maxPossibleBins < strategy.maxBins) {
        logWarning(s"DecisionTree reducing maxBins from ${strategy.maxBins} to $maxPossibleBins" +
          s" (= number of training instances)")
      }
      // We check the number of bins here against maxPossibleBins.
      // This needs to be checked here instead of in Strategy since maxPossibleBins can be modified
      // based on the number of training examples.
      // The number of categories of each categorical feature must not exceed maxPossibleBins.
      // categoricalFeaturesInfo is supplied by the caller and describes the categorical features:
      // an entry (n -> k) means feature n is categorical with categories 0, 1, ..., k-1.
      if (strategy.categoricalFeaturesInfo.nonEmpty) {
        val maxCategoriesPerFeature = strategy.categoricalFeaturesInfo.values.max
        val maxCategory =
          strategy.categoricalFeaturesInfo.find(_._2 == maxCategoriesPerFeature).get._1
        require(maxCategoriesPerFeature <= maxPossibleBins,
          s"DecisionTree requires maxBins (= $maxPossibleBins) to be at least as large as the " +
          s"number of values in each categorical feature, but categorical feature $maxCategory " +
          s"has $maxCategoriesPerFeature values. Considering remove this and other categorical " +
          "features with a large number of values, or add more training examples.")
      }
      val unorderedFeatures = new mutable.HashSet[Int]()
      val numBins = Array.fill[Int](numFeatures)(maxPossibleBins)
      if (numClasses > 2) {
        // Multiclass classification
        val maxCategoriesForUnorderedFeature =
          ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt
        strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
          // A categorical feature with only one category is treated as continuous
          if (numCategories > 1) {
            // Decide if some categorical features should be treated as unordered features,
            // which require 2 * ((1 << numCategories - 1) - 1) bins.
            // We do this check with log values to prevent overflows in case numCategories is large.
            // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins
            if (numCategories <= maxCategoriesForUnorderedFeature) {
              unorderedFeatures.add(featureIndex)
              numBins(featureIndex) = numUnorderedBins(numCategories)
            } else {
              numBins(featureIndex) = numCategories
            }
          }
        }
      } else {
        // Binary classification or regression
        strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) =>
          // A categorical feature with only one category is treated as continuous
          if (numCategories > 1) {
            numBins(featureIndex) = numCategories
          }
        }
      }
      // Number of features considered at each node (relevant for random forests)
      val _featureSubsetStrategy = featureSubsetStrategy match {
        case "auto" =>
          if (numTrees == 1) { // A single decision tree uses all features
            "all"
          } else {
            if (strategy.algo == Classification) { // Classification: square root of numFeatures
              "sqrt"
            } else { // Regression: one third of the features
              "onethird"
            }
          }
        case _ => featureSubsetStrategy
      }
      val numFeaturesPerNode: Int = _featureSubsetStrategy match {
        case "all" => numFeatures
        case "sqrt" => math.sqrt(numFeatures).ceil.toInt
        case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt)
        case "onethird" => (numFeatures / 3.0).ceil.toInt
      }
      new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
        strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
        strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth,
        strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode)
    }
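`buildMetadata` checks whether a feature can be treated as unordered in log space, which avoids overflowing `1 << numCategories`. The sketch below (a hypothetical `UnorderedCheck` object) puts that log-space bound next to the direct bin count 2 * (2^(numCategories-1) - 1) from the code comment, so the two can be compared. The direct version uses Long shifts and is only meant for small arities. For example, with maxPossibleBins = 32 (that is, at least 32 training examples), categorical features with up to 5 categories are treated as unordered in multiclass classification.

```scala
object UnorderedCheck {
  /** Bins needed by an unordered categorical feature with `arity` categories:
    * 2 * (2^(arity-1) - 1), i.e. a left and a right bin per candidate split. */
  def numUnorderedBins(arity: Int): Long = 2L * ((1L << (arity - 1)) - 1)

  /** The log-space bound computed in buildMetadata above. */
  def maxCategoriesForUnorderedFeature(maxBins: Int): Int =
    ((math.log(maxBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt

  /** Direct comparison; only safe for small arity because of the shift. */
  def fitsDirectly(arity: Int, maxBins: Int): Boolean = numUnorderedBins(arity) <= maxBins
}
```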

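      The second step of initialization is finding the split points (splits) and the bins. This is done by calling DecisionTree.findSplitsBins. When there are continuous features, this method first samples the input. Under the Sort quantile strategy, it then passes the sample to findSplitsBinsBySorting.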

      Let us step into findSplitsBinsBySorting to see how the Sort strategy is implemented.

    private def findSplitsBinsBySorting(
        input: RDD[LabeledPoint],
        metadata: DecisionTreeMetadata,
        continuousFeatures: IndexedSeq[Int]): (Array[Array[Split]], Array[Array[Bin]]) = {
      def findSplits(
          featureIndex: Int,
          featureSamples: Iterable[Double]): (Int, (Array[Split], Array[Bin])) = {
        // Each feature has its own sorted set of split thresholds;
        // findSplitsForContinuousFeature returns all split thresholds of a continuous feature.
        val splits = {
          val featureSplits = findSplitsForContinuousFeature(
            featureSamples.toArray,
            metadata,
            featureIndex)
          featureSplits.map(threshold => new Split(featureIndex, threshold, Continuous, Nil))
        }
        // Bins delimited by consecutive split points
        val bins = {
          // Dummy split with threshold Double.MinValue as the leftmost boundary
          val lowSplit = new DummyLowSplit(featureIndex, Continuous)
          // Dummy split with threshold Double.MaxValue as the rightmost boundary
          val highSplit = new DummyHighSplit(featureIndex, Continuous)
          // tack the dummy splits on either side of the computed splits
          val allSplits = lowSplit +: splits.toSeq :+ highSplit
          // Pair up adjacent split points; each pair forms one bin
          allSplits.sliding(2).map {
            case Seq(left, right) => new Bin(left, right, Continuous, Double.MinValue)
          }.toArray
        }
        (featureIndex, (splits, bins))
      }
      val continuousSplits = {
        // reduce the parallelism for split computations when there are less
        // continuous features than input partitions. this prevents tasks from
        // being spun up that will definitely do no work.
        val numPartitions = math.min(continuousFeatures.length, input.partitions.length)
        input
          .flatMap(point => continuousFeatures.map(idx => (idx, point.features(idx))))
          .groupByKey(numPartitions)
          .map { case (k, v) => findSplits(k, v) }
          .collectAsMap()
      }
      val numFeatures = metadata.numFeatures
      // Iterate over all features
      val (splits, bins) = Range(0, numFeatures).unzip {
        // Continuous feature
        case i if metadata.isContinuous(i) =>
          val (split, bin) = continuousSplits(i)
          metadata.setNumSplits(i, split.length)
          (split, bin)
        // Unordered categorical feature
        case i if metadata.isCategorical(i) && metadata.isUnordered(i) =>
          // Unordered features
          // 2^(maxFeatureValue - 1) - 1 combinations
          val featureArity = metadata.featureArity(i)
          val split = Range(0, metadata.numSplits(i)).map { splitIndex =>
            val categories = extractMultiClassCategories(splitIndex + 1, featureArity)
            new Split(i, Double.MinValue, Categorical, categories)
          }
          // For unordered categorical features, there is no need to construct the bins.
          // since there is a one-to-one correspondence between the splits and the bins.
          (split.toArray, Array.empty[Bin])
        // Ordered categorical feature: nothing to precompute, each bin is one category value
        case i if metadata.isCategorical(i) =>
          // Ordered features
          // Bins correspond to feature values, so we do not need to compute splits or bins
          // beforehand. Splits are constructed as needed during training.
          (Array.empty[Split], Array.empty[Bin])
      }
      (splits.toArray, bins.toArray)
    }
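The sketch below isolates two steps: building bins from sorted thresholds, and the later mapping of each raw value to a bin index. It assumes the thresholds are already sorted and that a value x goes left of a threshold t when x <= t, which is MLlib's convention for continuous splits. Infinities stand in for DummyLowSplit and DummyHighSplit, which use Double.MinValue and Double.MaxValue. `ContinuousBins` is a hypothetical name. The binary search is one reasonable way to do the lookup, not necessarily the one TreePoint uses.

```scala
object ContinuousBins {
  /** Turns sorted split thresholds into bins (lower, upper]; infinities play the
    * role of the dummy low/high splits. */
  def bins(thresholds: Array[Double]): Array[(Double, Double)] =
    (Double.NegativeInfinity +: thresholds :+ Double.PositiveInfinity)
      .sliding(2)
      .map(pair => (pair(0), pair(1)))
      .toArray

  /** Bin index of x: the number of thresholds strictly below x, found by binary
    * search (x <= threshold means "goes to the left"). */
  def binIndex(thresholds: Array[Double], x: Double): Int = {
    var lo = 0
    var hi = thresholds.length
    while (lo < hi) {
      val mid = (lo + hi) >>> 1
      if (x <= thresholds(mid)) hi = mid else lo = mid + 1
    }
    lo
  }
}
```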
      The split thresholds of each continuous feature are computed by findSplitsForContinuousFeature:

    private[tree] def findSplitsForContinuousFeature(
        featureSamples: Array[Double],
        metadata: DecisionTreeMetadata,
        featureIndex: Int): Array[Double] = {
      val splits = {
        // Number of splits = number of bins - 1
        val numSplits = metadata.numSplits(featureIndex)
        // (feature value, number of occurrences)
        val valueCountMap = featureSamples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
          m + ((x, m.getOrElse(x, 0) + 1))
        }
        // Sort by feature value
        val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
        // if possible splits is not enough or just enough, just return all possible splits
        val possibleSplits = valueCounts.length
        // If there are no more distinct values than splits, every distinct value is a split point
        if (possibleSplits <= numSplits) {
          valueCounts.map(_._1)
        } else {
          // Equal-frequency splitting
          // stride: number of samples between consecutive split points
          val stride: Double = featureSamples.length.toDouble / (numSplits + 1)
          val splitsBuilder = Array.newBuilder[Double]
          var index = 1
          // currentCount: sum of counts of values that have been visited
          // (starts with the count of the smallest value)
          var currentCount = valueCounts(0)._2
          // targetCount: target value for `currentCount`.
          // If `currentCount` is closest value to `targetCount`,
          // then current value is a split threshold.
          // After finding a split threshold, `targetCount` is added by stride.
          var targetCount = stride
          while (index < valueCounts.length) {
            val previousCount = currentCount
            currentCount += valueCounts(index)._2
            val previousGap = math.abs(previousCount - targetCount)
            val currentGap = math.abs(currentCount - targetCount)
            // If adding count of current value to currentCount
            // makes the gap between currentCount and targetCount smaller,
            // previous value is a split threshold.
            if (previousGap < currentGap) {
              splitsBuilder += valueCounts(index - 1)._1
              targetCount += stride
            }
            index += 1
          }
          splitsBuilder.result()
        }
      }
      splits
    }

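       In the else branch, targetCount starts at stride and advances by another stride samples each time a split is found. The while loop adds the count of each distinct value, in ascending order of value, to currentCount. It then compares how far previousCount (the cumulative count before the current value) and currentCount are from targetCount. There are three cases:
       (1) Both lie to the left of targetCount. currentCount is necessarily closer, so the loop continues and eventually reaches case 2.
       (2) They lie on opposite sides. If previousCount is closer, the previous value is the best split point. If currentCount is still closer, the loop continues and reaches case 3.
       (3) Both lie to the right. previousCount is clearly closer.
       Hence, whenever the condition previousGap < currentGap holds, the previous value is a split. Overall, each split threshold is the feature value whose cumulative count is closest to the corresponding targetCount.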
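The equal-frequency selection can be reproduced outside Spark to check the reasoning above. This sketch mirrors the loop of findSplitsForContinuousFeature, but it takes numSplits as a plain argument instead of reading it from DecisionTreeMetadata. `EqualFrequencySplits` is a hypothetical name.

```scala
object EqualFrequencySplits {
  /** Equal-frequency split thresholds, following the same gap comparison as above. */
  def findSplits(samples: Array[Double], numSplits: Int): Array[Double] = {
    // (distinct value, number of occurrences), sorted by value
    val valueCounts: Array[(Double, Int)] =
      samples.groupBy(identity).map { case (v, vs) => (v, vs.length) }
        .toArray.sortWith(_._1 < _._1)
    if (valueCounts.length <= numSplits) {
      // Few enough distinct values: every one of them is a split point
      valueCounts.map(_._1)
    } else {
      val stride = samples.length.toDouble / (numSplits + 1)
      val builder = Array.newBuilder[Double]
      var currentCount = valueCounts(0)._2
      var targetCount = stride
      var index = 1
      while (index < valueCounts.length) {
        val previousCount = currentCount
        currentCount += valueCounts(index)._2
        // Previous cumulative count is closer to the target: previous value is a split
        if (math.abs(previousCount - targetCount) < math.abs(currentCount - targetCount)) {
          builder += valueCounts(index - 1)._1
          targetCount += stride
        }
        index += 1
      }
      builder.result()
    }
  }
}
```

Take the ten values 1.0 to 10.0 with numSplits = 3. The stride is 2.5, and the targets are 2.5, 5.0 and 7.5, in that order. The selected thresholds are 3.0, 5.0 and 8.0. A tie such as cumulative counts 2 and 3 around the target 2.5 resolves to the larger value, because the comparison is strict.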

    5.1.2 Iteratively Building the Random Forest

    // Optional node ID cache. Node IDs start at 1: node 1 is the root of a tree, its left child
    // is 2 and its right child is 3, and so on (the children of node i are 2i and 2i + 1).
    val nodeIdCache = if (strategy.useNodeIdCache) {
      Some(NodeIdCache.init(
        data = baggedInput,
        numTrees = numTrees,
        checkpointInterval = strategy.checkpointInterval,
        initVal = 1))
    } else {
      None
    }
    // FIFO queue of nodes to train: (treeIndex, node)
    val nodeQueue = new mutable.Queue[(Int, Node)]()
    val rng = new scala.util.Random()
    rng.setSeed(seed)
    // Allocate and queue root nodes.
    val topNodes: Array[Node] = Array.fill[Node](numTrees)(Node.emptyNode(nodeIndex = 1))
    // Enqueue (tree index, root node); tree indices start at 0, node IDs at 1
    Range(0, numTrees).foreach(treeIndex => nodeQueue.enqueue((treeIndex, topNodes(treeIndex))))
    while (nodeQueue.nonEmpty) {
      // Collect some nodes to split, and choose features for each node (if subsampling).
      // Each group of nodes may come from one or multiple trees, and at multiple levels.
      // nodesForGroup: the nodes of each tree to be split in this iteration
      val (nodesForGroup, treeToNodeToIndexInfo) =
        RandomForest.selectNodesToSplit(nodeQueue, maxMemoryUsage, metadata, rng)
      // Find the best split for each of these nodes
      DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
        treeToNodeToIndexInfo, splits, bins, nodeQueue, timer, nodeIdCache = nodeIdCache)
    }

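      Two points deserve a closer look. The first is selecting the nodes of each tree that need to be split, implemented by RandomForest.selectNodesToSplit. The second is finding the best splits, implemented by DecisionTree.findBestSplits. They are described in turn below.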
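To make the level-wise (breadth-first) strategy concrete, here is a small simulation that only tracks node IDs. It uses the same numbering as the node ID cache comment above (root = 1, and the children of node i are 2i and 2i + 1). Each loop iteration stands for one pass over the data that serves every pending node of every tree.

This is a simplification. MLlib's selectNodesToSplit groups nodes by a memory budget, so one group may mix levels or cover only part of a level. `LevelWiseSketch` and the `shouldSplit` callback are hypothetical.

```scala
import scala.collection.mutable

object LevelWiseSketch {
  /** Grows binary trees breadth-first with heap-style node IDs. Each iteration of the
    * while loop models ONE aggregation pass over the data for all queued nodes.
    * Returns (number of passes, node IDs of each tree). */
  def grow(numTrees: Int, maxDepth: Int,
           shouldSplit: (Int, Int) => Boolean): (Int, Seq[Set[Int]]) = {
    val nodes = Array.fill(numTrees)(mutable.Set(1))
    // (treeIndex, nodeId) pairs still waiting to be split; all roots to start with
    var frontier: Seq[(Int, Int)] = (0 until numTrees).map(t => (t, 1))
    var depth = 0
    var passes = 0
    while (frontier.nonEmpty && depth < maxDepth) {
      passes += 1 // one pass computes the statistics for the whole level
      frontier = frontier.flatMap { case (t, id) =>
        if (shouldSplit(t, id)) {
          nodes(t) ++= Seq(2 * id, 2 * id + 1)
          Seq((t, 2 * id), (t, 2 * id + 1))
        } else Seq.empty
      }
      depth += 1
    }
    (passes, nodes.map(_.toSet).toSeq)
  }
}
```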

    • Selecting the nodes to split in each tree
    • Choosing the best split
    // All nodes to split in this group, indexed by their position within the group
    val nodes = new Array[Node](numNodes)
    nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
      nodesForTree.foreach { node =>
        nodes(treeToNodeToIndexInfo(treeIndex)(node.id).nodeIndexInGroup) = node
      }
    }
    // In each partition, iterate all instances and compute aggregate stats for each node,
    // yield an (nodeIndex, nodeAggregateStats) pair for each node.
    // After a `reduceByKey` operation,
    // stats of a node will be shuffled to a particular partition and be combined together,
    // then best splits for nodes are found there.
    // Finally, only best Splits for nodes are collected to driver to construct decision tree.
    // Feature subset of each node (None when features are not subsampled)
    val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
    val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)
    val partitionAggregates : RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) {
      input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points =>
        // Construct a nodeStatsAggregators array to hold node aggregate stats,
        // each node will have a nodeStatsAggregator
        val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
          // Features used by this node
          val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
            Some(nodeToFeatures(nodeIndex))
          }
          // DTStatsAggregator wraps an ImpurityAggregator, which supplies the impurity computation
          new DTStatsAggregator(metadata, featuresForNode)
        }
        // Iterate over the points of this partition and update the aggregated statistics
        // (each point contributes with its sampling weight for the corresponding tree)
        points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _))
        // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
        // which can be combined with other partition using `reduceByKey`
        nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
      }
    } else {
      input.mapPartitions { points =>
        // Construct a nodeStatsAggregators array to hold node aggregate stats,
        val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
          // Features used by this node
          val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
            Some(nodeToFeatures(nodeIndex))
          }
          // DTStatsAggregator wraps an ImpurityAggregator, which supplies the impurity computation
          new DTStatsAggregator(metadata, featuresForNode)
        }
        // Iterate over the points of this partition and update the aggregated statistics
        points.foreach(binSeqOp(nodeStatsAggregators, _))
        // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
        // which can be combined with other partition using `reduceByKey`
        nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
      }
    }
    val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b))
      .map { case (nodeIndex, aggStats) =>
        val featuresForNode = nodeToFeaturesBc.value.map { nodeToFeatures =>
          nodeToFeatures(nodeIndex)
        }
        // find best split for each node
        val (split: Split, stats: InformationGainStats, predict: Predict) =
          binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex))
        (nodeIndex, (split, stats, predict))
      }.collectAsMap()
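The mapPartitions plus reduceByKey pattern works because bin statistics are additive. Aggregating each partition separately and then merging per node gives the same result as aggregating everything at once. The sketch below checks this with plain Scala collections standing in for RDD partitions.

For the variance impurity, the sketch keeps (weighted count, sum, sum of squares) per bin. I believe these match the three statistics MLlib's variance aggregator stores. `NodeBinAggregation`, `Point` and `Stats` are hypothetical names, and only a single feature is modeled.

```scala
object NodeBinAggregation {
  /** Sufficient statistics for variance impurity: weighted count, sum, sum of squares. */
  final case class Stats(count: Double, sum: Double, sumSq: Double) {
    def merge(o: Stats): Stats = Stats(count + o.count, sum + o.sum, sumSq + o.sumSq)
  }
  object Stats { val empty: Stats = Stats(0.0, 0.0, 0.0) }

  /** A point already routed to a node and binned on one feature, with its sampling weight. */
  final case class Point(node: Int, bin: Int, label: Double, weight: Double)

  /** Like the mapPartitions step: one partition -> per-node vector of per-bin stats. */
  def aggregatePartition(points: Seq[Point], numBins: Int): Map[Int, Vector[Stats]] =
    points.groupBy(_.node).map { case (node, ps) =>
      val perBin = Vector.tabulate(numBins) { b =>
        ps.filter(_.bin == b).foldLeft(Stats.empty) { (s, p) =>
          s.merge(Stats(p.weight, p.weight * p.label, p.weight * p.label * p.label))
        }
      }
      node -> perBin
    }

  /** Like reduceByKey: merge the per-node vectors from different partitions bin by bin. */
  def mergeAll(partials: Seq[Map[Int, Vector[Stats]]]): Map[Int, Vector[Stats]] =
    partials.flatten.groupBy(_._1).map { case (node, entries) =>
      node -> entries.map(_._2).reduce((a, b) => a.zip(b).map { case (x, y) => x.merge(y) })
    }
}
```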

      The key step in findBestSplits is the call to binsToBestSplit, whose code is as follows:

    private def binsToBestSplit(
        binAggregates: DTStatsAggregator,
        splits: Array[Array[Split]],
        featuresForNode: Option[Array[Int]],
        node: Node): (Split, InformationGainStats, Predict) = {
      // For the root node (level 0), the prediction and impurity are not known yet and are
      // computed below; any other node reuses the values stored when its parent was split.
      val level = Node.indexToLevel(node.id)
      var predictWithImpurity: Option[(Predict, Double)] = if (level == 0) {
        None
      } else {
        Some((node.predict, node.impurity))
      }
      // For every (feature, split), compute the information gain and keep the best one
      val (bestSplit, bestSplitStats) =
        Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
          val featureIndex = if (featuresForNode.nonEmpty) {
            featuresForNode.get.apply(featureIndexIdx)
          } else {
            featureIndexIdx
          }
          val numSplits = binAggregates.metadata.numSplits(featureIndex)
          // Continuous feature
          if (binAggregates.metadata.isContinuous(featureIndex)) {
            // Cumulative sum (scanLeft) of bin statistics.
            // Afterwards, binAggregates for a bin is the sum of aggregates for
            // that bin + all preceding bins.
            val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
            var splitIndex = 0
            while (splitIndex < numSplits) {
              binAggregates.mergeForFeature(nodeFeatureOffset, splitIndex + 1, splitIndex)
              splitIndex += 1
            }
            // Find best split.
            val (bestFeatureSplitIndex, bestFeatureGainStats) =
              Range(0, numSplits).map { case splitIdx =>
                // Impurity statistics of the children: left = cumulative stats up to splitIdx,
                // right = total (last bin) minus left
                val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
                val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
                rightChildStats.subtract(leftChildStats)
                // Prediction (the mean, for regression) and impurity of the node, computed once
                predictWithImpurity = Some(predictWithImpurity.getOrElse(
                  calculatePredictImpurity(leftChildStats, rightChildStats)))
                // Information gain, used to judge whether this split is the best one
                // (see section 1.4.4 of the decision tree article)
                val gainStats = calculateGainForSplit(leftChildStats,
                  rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
                (splitIdx, gainStats)
              }.maxBy(_._2.gain)
            (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
          }
          // Unordered categorical feature
          else if (binAggregates.metadata.isUnordered(featureIndex)) {
            val (leftChildOffset, rightChildOffset) =
              binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
            val (bestFeatureSplitIndex, bestFeatureGainStats) =
              Range(0, numSplits).map { splitIndex =>
                val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
                val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
                predictWithImpurity = Some(predictWithImpurity.getOrElse(
                  calculatePredictImpurity(leftChildStats, rightChildStats)))
                val gainStats = calculateGainForSplit(leftChildStats,
                  rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
                (splitIndex, gainStats)
              }.maxBy(_._2.gain)
            (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
          } else { // Ordered categorical feature
            val nodeFeatureOffset = binAggregates.getFeatureOffset(featureIndexIdx)
            val numBins = binAggregates.metadata.numBins(featureIndex)
            /* Each bin is one category (feature value).
             * The bins are ordered based on centroidForCategories, and this ordering determines which
             * splits are considered. (With K categories, we consider K - 1 possible splits.)
             *
             * centroidForCategories is a list: (category, centroid)
             */
            // Multiclass classification
            val centroidForCategories = if (binAggregates.metadata.isMulticlass) {
              // For categorical variables in multiclass classification,
              // the bins are ordered by the impurity of their corresponding labels.
              Range(0, numBins).map { case featureValue =>
                val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
                val centroid = if (categoryStats.count != 0) {
                  // Centroid = impurity (e.g. Gini or entropy) of the labels in this category
                  categoryStats.calculate()
                } else {
                  Double.MaxValue
                }
                (featureValue, centroid)
              }
            } else { // Regression or binary classification
              // For categorical variables in regression and binary classification,
              // the bins are ordered by the centroid of their corresponding labels.
              Range(0, numBins).map { case featureValue =>
                val categoryStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
                val centroid = if (categoryStats.count != 0) {
                  // Centroid = predicted value of this category (the mean label, for regression)
                  categoryStats.predict
                } else {
                  Double.MaxValue
                }
                (featureValue, centroid)
              }
            }
            // bins sorted by centroids
            val categoriesSortedByCentroid = centroidForCategories.toList.sortBy(_._2)
            // Cumulative sum (scanLeft) of bin statistics.
            // Afterwards, binAggregates for a bin is the sum of aggregates for
            // that bin + all preceding bins.
            var splitIndex = 0
            while (splitIndex < numSplits) {
              val currentCategory = categoriesSortedByCentroid(splitIndex)._1
              val nextCategory = categoriesSortedByCentroid(splitIndex + 1)._1
              // Add the stats of currentCategory's bin into nextCategory's bin
              binAggregates.mergeForFeature(nodeFeatureOffset, nextCategory, currentCategory)
              splitIndex += 1
            }
            // lastCategory = index of bin with total aggregates for this (node, feature)
            val lastCategory = categoriesSortedByCentroid.last._1
            // Find best split (the split with the largest information gain).
            val (bestFeatureSplitIndex, bestFeatureGainStats) =
              Range(0, numSplits).map { splitIndex =>
                val featureValue = categoriesSortedByCentroid(splitIndex)._1
                val leftChildStats =
                  binAggregates.getImpurityCalculator(nodeFeatureOffset, featureValue)
                val rightChildStats =
                  binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
                rightChildStats.subtract(leftChildStats)
                predictWithImpurity = Some(predictWithImpurity.getOrElse(
                  calculatePredictImpurity(leftChildStats, rightChildStats)))
                val gainStats = calculateGainForSplit(leftChildStats,
                  rightChildStats, binAggregates.metadata, predictWithImpurity.get._2)
                (splitIndex, gainStats)
              }.maxBy(_._2.gain)
            val categoriesForSplit =
              categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
            val bestFeatureSplit =
              new Split(featureIndex, Double.MinValue, Categorical, categoriesForSplit)
            (bestFeatureSplit, bestFeatureGainStats)
          }
        }.maxBy(_._2.gain)
      (bestSplit, bestSplitStats, predictWithImpurity.get._1)
    }
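The ordered-categorical branch keeps categorical features cheap. Instead of trying 2^(K-1) - 1 subsets, it sorts the K categories by centroid and scans only the K - 1 prefix splits using cumulative statistics. The gain of each split is impurity(parent) - wL * impurity(left) - wR * impurity(right), where wL and wR are the fractions of samples going left and right.

Below is a standalone sketch for the regression case. It assumes variance impurity and drops empty categories (MLlib instead gives them the centroid Double.MaxValue). It also requires at least two non-empty categories. `OrderedCategoricalSplit` is a hypothetical name.

```scala
object OrderedCategoricalSplit {
  /** Sufficient statistics of the labels in one category: count, sum, sum of squares. */
  final case class Stats(count: Double, sum: Double, sumSq: Double) {
    def +(o: Stats): Stats = Stats(count + o.count, sum + o.sum, sumSq + o.sumSq)
    def -(o: Stats): Stats = Stats(count - o.count, sum - o.sum, sumSq - o.sumSq)
    def mean: Double = sum / count
    /** Variance impurity E[y^2] - E[y]^2 (0 for an empty node). */
    def variance: Double = if (count == 0) 0.0 else sumSq / count - mean * mean
  }

  /** Returns (categories sent to the left child, information gain) of the best split.
    * Needs at least two non-empty categories. */
  def bestSplit(statsByCategory: Map[Int, Stats]): (Seq[Int], Double) = {
    // 1. Order the non-empty categories by centroid (mean label).
    val sorted = statsByCategory.toSeq.filter(_._2.count > 0).sortWith(_._2.mean < _._2.mean)
    val total = sorted.map(_._2).reduce(_ + _)
    // 2. Cumulative sums: prefix(i) holds the stats of the first i + 1 categories.
    val prefix = sorted.map(_._2).scanLeft(Stats(0.0, 0.0, 0.0))(_ + _).tail
    // 3. Evaluate the K - 1 prefix splits; right child = total - left child.
    val candidates = (0 until sorted.size - 1).map { i =>
      val left = prefix(i)
      val right = total - left
      val gain = total.variance -
        (left.count / total.count) * left.variance -
        (right.count / total.count) * right.variance
      (sorted.take(i + 1).map(_._1), gain)
    }
    candidates.reduce((a, b) => if (b._2 > a._2) b else a)
  }
}
```

Suppose categories 0, 1 and 2 each hold a single label: 1, 9 and 2. They sort by centroid as 0, 2, 1. The best split sends {0, 2} to the left with a gain of 12.5.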

    5.2 Prediction Analysis

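      When a random forest makes predictions, the predict method it calls comes from TreeEnsembleModel, the class that represents tree-ensemble models. Its core code is as follows: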

    // Each combining strategy uses a different prediction method
    def predict(features: Vector): Double = {
      (algo, combiningStrategy) match {
        case (Regression, Sum) =>
          predictBySumming(features)
        case (Regression, Average) =>
          predictBySumming(features) / sumWeights
        case (Classification, Sum) => // binary classification
          val prediction = predictBySumming(features)
          // TODO: predicted labels are +1 or -1 for GBT. Need a better way to store this info.
          if (prediction > 0.0) 1.0 else 0.0
        case (Classification, Vote) =>
          predictByVoting(features)
        case _ =>
          throw new IllegalArgumentException()
      }
    }
    private def predictBySumming(features: Vector): Double = {
      val treePredictions = trees.map(_.predict(features))
      // Dot product of the tree predictions and the tree weights
      blas.ddot(numTrees, treePredictions, 1, treeWeights, 1)
    }
    // Weighted majority vote
    private def predictByVoting(features: Vector): Double = {
      val votes = mutable.Map.empty[Int, Double]
      trees.view.zip(treeWeights).foreach { case (tree, weight) =>
        val prediction = tree.predict(features).toInt
        votes(prediction) = votes.getOrElse(prediction, 0.0) + weight
      }
      votes.maxBy(_._2)._1
    }
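For a random forest, as far as I know, RandomForestModel fills treeWeights with 1.0. Average therefore reduces to the mean of the tree outputs, and Vote to an unweighted majority vote. The sketch below restates these two paths on plain sequences of per-tree predictions. `EnsemblePrediction` is a hypothetical name. As in the original, a tie between labels is resolved by map iteration order rather than at random.

```scala
object EnsemblePrediction {
  /** Regression, Average strategy: weighted sum of tree outputs divided by the sum of weights. */
  def predictAverage(treePredictions: Seq[Double], treeWeights: Seq[Double]): Double =
    treePredictions.zip(treeWeights).map { case (p, w) => p * w }.sum / treeWeights.sum

  /** Classification, Vote strategy: each tree adds its weight to the label it predicts. */
  def predictByVoting(treePredictions: Seq[Double], treeWeights: Seq[Double]): Double = {
    val votes: Map[Int, Double] = treePredictions.map(_.toInt).zip(treeWeights)
      .groupBy(_._1)
      .map { case (label, pairs) => label -> pairs.map(_._2).sum }
    votes.reduce((a, b) => if (b._2 > a._2) b else a)._1.toDouble
  }
}
```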

    References

    [1] Zhou Zhihua. Machine Learning.