spark AQEPropagateEmptyRelation 源码

  • 2022-10-20
  • 浏览 (390)

spark AQEPropagateEmptyRelation 代码

文件路径:/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.adaptive

import org.apache.spark.sql.catalyst.optimizer.PropagateEmptyRelationBase
import org.apache.spark.sql.catalyst.planning.ExtractSingleColumnNullAwareAntiJoin
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, LOGICAL_QUERY_STAGE, TRUE_OR_FALSE_LITERAL}
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
import org.apache.spark.sql.execution.exchange.{REPARTITION_BY_COL, REPARTITION_BY_NUM, ShuffleExchangeLike}
import org.apache.spark.sql.execution.joins.HashedRelationWithAllNullKeys

/**
 * AQE-side variant of [[PropagateEmptyRelationBase]]. On top of the checks inherited from the
 * base rule, it:
 * 1. Decides whether a plan is empty / non-empty from the runtime row count of materialized
 *    query stages (see `getEstimatedRowCount`).
 * 2. Handles a single column NULL-aware anti join (NAAJ) whose broadcasted [[HashedRelation]]
 *    is [[HashedRelationWithAllNullKeys]]: the join is eliminated to an empty
 *    [[LocalRelation]].
 */
object AQEPropagateEmptyRelation extends PropagateEmptyRelationBase {
  /**
   * A plan is empty when the base rule says so, or when it is not tagged as a root repartition
   * and its runtime row count is exactly 0.
   */
  override protected def isEmpty(plan: LogicalPlan): Boolean = {
    // Only evaluated when the base check fails, mirroring the short-circuit of `||`.
    def knownEmptyAtRuntime: Boolean =
      !isRootRepartition(plan) && getEstimatedRowCount(plan).contains(0)

    super.isEmpty(plan) || knownEmptyAtRuntime
  }

  /**
   * A plan is non-empty when the base rule says so, or when its runtime row count is positive.
   */
  override protected def nonEmpty(plan: LogicalPlan): Boolean = {
    // Only evaluated when the base check fails, mirroring the short-circuit of `||`.
    def hasRowsAtRuntime: Boolean = getEstimatedRowCount(plan).exists(_ > 0)

    super.nonEmpty(plan) || hasRowsAtRuntime
  }

  /** True only for a [[LogicalQueryStage]] that carries the `ROOT_REPARTITION` tag. */
  private def isRootRepartition(plan: LogicalPlan): Boolean = plan match {
    case queryStage: LogicalQueryStage => queryStage.getTagValue(ROOT_REPARTITION).nonEmpty
    case _ => false
  }

  /**
   * Runtime row count of `plan`, if one is available:
   *   - `Some(0)`: the plan must produce 0 row
   *   - `Some(n)` with `n > 0`: an estimated row count, which can be over-estimated
   *   - `None`: the plan has not materialized, or it can not be estimated
   */
  private def getEstimatedRowCount(plan: LogicalPlan): Option[BigInt] = plan match {
    // The logical stage wraps a materialized query stage directly.
    case LogicalQueryStage(_, materialized: QueryStageExec) if materialized.isMaterialized =>
      materialized.getRuntimeStatistics.rowCount

    // The logical stage wraps an aggregate with grouping expressions: use the statistics of its
    // child, provided that child is a materialized query stage.
    case LogicalQueryStage(_, agg: BaseAggregateExec) if agg.groupingExpressions.nonEmpty =>
      agg.child match {
        case input: QueryStageExec if input.isMaterialized => input.getRuntimeStatistics.rowCount
        case _ => None
      }

    case _ => None
  }

  /**
   * True when `plan` is a materialized broadcast query stage whose broadcast value is
   * [[HashedRelationWithAllNullKeys]].
   */
  private def isRelationWithAllNullKeys(plan: LogicalPlan): Boolean = plan match {
    case LogicalQueryStage(_, broadcastStage: BroadcastQueryStageExec)
        if broadcastStage.isMaterialized =>
      val broadcastValue = broadcastStage.broadcast.relationFuture.get().value
      broadcastValue == HashedRelationWithAllNullKeys
    case _ => false
  }

  /**
   * Replaces a single column NULL-aware anti join with `empty(join)` when its right side is a
   * relation with all-null keys.
   */
  private def eliminateSingleColumnNullAwareAntiJoin: PartialFunction[LogicalPlan, LogicalPlan] = {
    case join @ ExtractSingleColumnNullAwareAntiJoin(_, _)
        if isRelationWithAllNullKeys(join.right) =>
      empty(join)
  }

  /**
   * True when `p` is a shuffle query stage whose exchange has shuffle origin
   * `REPARTITION_BY_COL` or `REPARTITION_BY_NUM`.
   */
  override protected def userSpecifiedRepartition(p: LogicalPlan): Boolean = p match {
    case LogicalQueryStage(_, ShuffleQueryStageExec(_, exchange: ShuffleExchangeLike, _)) =>
      exchange.shuffleOrigin match {
        case REPARTITION_BY_COL | REPARTITION_BY_NUM => true
        case _ => false
      }
    case _ => false
  }

  /**
   * Applies the NAAJ elimination first and falls back to the shared
   * `PropagateEmptyRelationBase.commonApplyFunc`, traversing the plan bottom-up.
   */
  override protected def applyInternal(p: LogicalPlan): LogicalPlan = {
    val rule = eliminateSingleColumnNullAwareAntiJoin.orElse(commonApplyFunc)
    // Pruning patterns:
    //   - LOCAL_RELATION and TRUE_OR_FALSE_LITERAL are matched at
    //     `PropagateEmptyRelationBase.commonApplyFunc`
    //   - LOGICAL_QUERY_STAGE is matched at `PropagateEmptyRelationBase.commonApplyFunc` and
    //     `AQEPropagateEmptyRelation.eliminateSingleColumnNullAwareAntiJoin`
    // No ruleId is passed here because the LogicalQueryStage is not immutable.
    p.transformUpWithPruning(
      _.containsAnyPattern(LOGICAL_QUERY_STAGE, LOCAL_RELATION, TRUE_OR_FALSE_LITERAL))(rule)
  }
}

相关信息

spark 源码目录

相关文章

spark AQEOptimizer 源码

spark AQEShuffleReadExec 源码

spark AQEShuffleReadRule 源码

spark AQEUtils 源码

spark AdaptiveRulesHolder 源码

spark AdaptiveSparkPlanExec 源码

spark AdaptiveSparkPlanHelper 源码

spark AdjustShuffleExchangePosition 源码

spark CoalesceShufflePartitions 源码

spark DynamicJoinSelection 源码

0  赞