spark decimalExpressions 源码

  • 2022-10-20
  • 浏览 (207)

spark decimalExpressions 代码

文件路径:/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.SQLQueryContext
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

/**
 * Return the unscaled Long value of a Decimal, assuming it fits in a Long.
 * Note: this expression is internal and created only by the optimizer,
 * we don't need to do type check for it.
 */
case class UnscaledValue(child: Expression) extends UnaryExpression with NullIntolerant {

  override def dataType: DataType = LongType
  override def toString: String = s"UnscaledValue($child)"

  protected override def nullSafeEval(input: Any): Any =
    input.asInstanceOf[Decimal].toUnscaledLong

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    defineCodeGen(ctx, ev, c => s"$c.toUnscaledLong()")
  }

  override protected def withNewChildInternal(newChild: Expression): UnscaledValue =
    copy(child = newChild)
}

/**
 * Create a Decimal from an unscaled Long value.
 * Note: this expression is internal and created only by the optimizer,
 * we don't need to do type check for it.
 */
case class MakeDecimal(
    child: Expression,
    precision: Int,
    scale: Int,
    nullOnOverflow: Boolean) extends UnaryExpression with NullIntolerant {

  def this(child: Expression, precision: Int, scale: Int) = {
    this(child, precision, scale, !SQLConf.get.ansiEnabled)
  }

  override def dataType: DataType = DecimalType(precision, scale)
  override def nullable: Boolean = child.nullable || nullOnOverflow
  override def toString: String = s"MakeDecimal($child,$precision,$scale)"

  protected override def nullSafeEval(input: Any): Any = {
    val longInput = input.asInstanceOf[Long]
    val result = new Decimal()
    if (nullOnOverflow) {
      result.setOrNull(longInput, precision, scale)
    } else {
      result.set(longInput, precision, scale)
    }
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(ctx, ev, eval => {
      val setMethod = if (nullOnOverflow) {
        "setOrNull"
      } else {
        "set"
      }
      val setNull = if (nullable) {
        s"${ev.isNull} = ${ev.value} == null;"
      } else {
        ""
      }
      s"""
         |${ev.value} = (new Decimal()).$setMethod($eval, $precision, $scale);
         |$setNull
         |""".stripMargin
    })
  }

  override protected def withNewChildInternal(newChild: Expression): MakeDecimal =
    copy(child = newChild)
}

object MakeDecimal {
  def apply(child: Expression, precision: Int, scale: Int): MakeDecimal = {
    new MakeDecimal(child, precision, scale)
  }
}

/**
 * Rounds the decimal to given scale and check whether the decimal can fit in provided precision
 * or not. If not, if `nullOnOverflow` is `true`, it returns `null`; otherwise an
 * `ArithmeticException` is thrown.
 */
case class CheckOverflow(
    child: Expression,
    dataType: DecimalType,
    nullOnOverflow: Boolean) extends UnaryExpression with SupportQueryContext {

  override def nullable: Boolean = true

  override def nullSafeEval(input: Any): Any =
    input.asInstanceOf[Decimal].toPrecision(
      dataType.precision,
      dataType.scale,
      Decimal.ROUND_HALF_UP,
      nullOnOverflow,
      getContextOrNull())

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val errorContextCode = getContextOrNullCode(ctx, !nullOnOverflow)
    nullSafeCodeGen(ctx, ev, eval => {
      // scalastyle:off line.size.limit
      s"""
         |${ev.value} = $eval.toPrecision(
         |  ${dataType.precision}, ${dataType.scale}, Decimal.ROUND_HALF_UP(), $nullOnOverflow, $errorContextCode);
         |${ev.isNull} = ${ev.value} == null;
       """.stripMargin
      // scalastyle:on line.size.limit
    })
  }

  override def toString: String = s"CheckOverflow($child, $dataType)"

  override def sql: String = child.sql

  override protected def withNewChildInternal(newChild: Expression): CheckOverflow =
    copy(child = newChild)

  override def initQueryContext(): Option[SQLQueryContext] = if (!nullOnOverflow) {
    Some(origin.context)
  } else {
    None
  }
}

// A variant `CheckOverflow`, which treats null as overflow. This is necessary in `Sum`.
case class CheckOverflowInSum(
    child: Expression,
    dataType: DecimalType,
    nullOnOverflow: Boolean,
    context: SQLQueryContext) extends UnaryExpression with SupportQueryContext {

  override def nullable: Boolean = true

  override def eval(input: InternalRow): Any = {
    val value = child.eval(input)
    if (value == null) {
      if (nullOnOverflow) null
      else throw QueryExecutionErrors.overflowInSumOfDecimalError(context)
    } else {
      value.asInstanceOf[Decimal].toPrecision(
        dataType.precision,
        dataType.scale,
        Decimal.ROUND_HALF_UP,
        nullOnOverflow,
        context)
    }
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val childGen = child.genCode(ctx)
    val errorContextCode = getContextOrNullCode(ctx, !nullOnOverflow)
    val nullHandling = if (nullOnOverflow) {
      ""
    } else {
      s"throw QueryExecutionErrors.overflowInSumOfDecimalError($errorContextCode);"
    }
    // scalastyle:off line.size.limit
    val code = code"""
       |${childGen.code}
       |boolean ${ev.isNull} = ${childGen.isNull};
       |Decimal ${ev.value} = null;
       |if (${childGen.isNull}) {
       |  $nullHandling
       |} else {
       |  ${ev.value} = ${childGen.value}.toPrecision(
       |    ${dataType.precision}, ${dataType.scale}, Decimal.ROUND_HALF_UP(), $nullOnOverflow, $errorContextCode);
       |  ${ev.isNull} = ${ev.value} == null;
       |}
       |""".stripMargin
    // scalastyle:on line.size.limit

    ev.copy(code = code)
  }

  override def toString: String = s"CheckOverflowInSum($child, $dataType, $nullOnOverflow)"

  override def sql: String = child.sql

  override protected def withNewChildInternal(newChild: Expression): CheckOverflowInSum =
    copy(child = newChild)

  override def initQueryContext(): Option[SQLQueryContext] = Option(context)
}

/**
 * An add expression for decimal values which is only used internally by Sum/Avg.
 *
 * Nota that, this expression does not check overflow which is different with `Add`. When
 * aggregating values, Spark writes the aggregation buffer values to `UnsafeRow` via
 * `UnsafeRowWriter`, which already checks decimal overflow, so we don't need to do it again in the
 * add expression used by Sum/Avg.
 */
case class DecimalAddNoOverflowCheck(
    left: Expression,
    right: Expression,
    override val dataType: DataType) extends BinaryOperator {
  require(dataType.isInstanceOf[DecimalType])

  override def inputType: AbstractDataType = DecimalType
  override def symbol: String = "+"
  private def decimalMethod: String = "$plus"

  private lazy val numeric = TypeUtils.getNumeric(dataType)

  override protected def nullSafeEval(input1: Any, input2: Any): Any =
    numeric.plus(input1, input2)

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)")

  override protected def withNewChildrenInternal(
      newLeft: Expression, newRight: Expression): DecimalAddNoOverflowCheck =
    copy(left = newLeft, right = newRight)
}

/**
 * A divide expression for decimal values which is only used internally by Avg.
 *
 * It will fail when nullOnOverflow is false follows:
 *   - left (sum in avg) is null due to over the max precision 38,
 *     the right (count in avg) should never be null
 *   - the result of divide is overflow
 */
case class DecimalDivideWithOverflowCheck(
    left: Expression,
    right: Expression,
    override val dataType: DecimalType,
    context: SQLQueryContext,
    nullOnOverflow: Boolean)
  extends BinaryExpression with ExpectsInputTypes with SupportQueryContext {
  override def nullable: Boolean = nullOnOverflow
  override def inputTypes: Seq[AbstractDataType] = Seq(DecimalType, DecimalType)
  override def initQueryContext(): Option[SQLQueryContext] = Option(context)
  def decimalMethod: String = "$div"

  override def eval(input: InternalRow): Any = {
    val value1 = left.eval(input)
    if (value1 == null) {
      if (nullOnOverflow)  {
        null
      } else {
        throw QueryExecutionErrors.overflowInSumOfDecimalError(getContextOrNull())
      }
    } else {
      val value2 = right.eval(input)
      dataType.fractional.asInstanceOf[Fractional[Any]].div(value1, value2).asInstanceOf[Decimal]
        .toPrecision(dataType.precision, dataType.scale, Decimal.ROUND_HALF_UP, nullOnOverflow,
          getContextOrNull())
    }
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val errorContextCode = getContextOrNullCode(ctx, !nullOnOverflow)
    val nullHandling = if (nullOnOverflow) {
      ""
    } else {
      s"throw QueryExecutionErrors.overflowInSumOfDecimalError($errorContextCode);"
    }

    val eval1 = left.genCode(ctx)
    val eval2 = right.genCode(ctx)

    // scalastyle:off line.size.limit
    val code =
      code"""
         |${eval1.code}
         |${eval2.code}
         |boolean ${ev.isNull} = ${eval1.isNull};
         |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
         |if (${eval1.isNull}) {
         |  $nullHandling
         |} else {
         |  ${ev.value} = ${eval1.value}.$decimalMethod(${eval2.value}).toPrecision(
         |      ${dataType.precision}, ${dataType.scale}, Decimal.ROUND_HALF_UP(), $nullOnOverflow, $errorContextCode);
         |  ${ev.isNull} = ${ev.value} == null;
         |}
      """.stripMargin
    // scalastyle:on line.size.limit
    ev.copy(code = code)
  }

  override protected def withNewChildrenInternal(
      newLeft: Expression, newRight: Expression): Expression = {
    copy(left = newLeft, right = newRight)
  }
}

相关信息

spark 源码目录

相关文章

spark AliasHelper 源码

spark ApplyFunctionExpression 源码

spark AttributeSet 源码

spark BloomFilterMightContain 源码

spark BoundAttribute 源码

spark CallMethodViaReflection 源码

spark Cast 源码

spark CodeGeneratorWithInterpretedFallback 源码

spark DynamicPruning 源码

spark EquivalentExpressions 源码

0  赞