Spark mathExpressions source code
Spark mathExpressions code
File path: /sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions
import java.{lang => jl}
import java.util.Locale
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{NumberConverter, TypeUtils}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
/**
 * Base class for nullary math functions that always evaluate to the same `Double` constant
 * (such as `e()` and `pi()`). They take no input.
 *
 * No code generation is provided (the class mixes in `CodegenFallback`): instances are
 * foldable, so they should be constant-folded by the optimizer.
 *
 * @param c    the constant this expression evaluates to
 * @param name the short function name, used for `prettyName` and `toString`
 */
abstract class LeafMathExpression(c: Double, name: String)
  extends LeafExpression with CodegenFallback with Serializable {

  override def nullable: Boolean = false
  override def foldable: Boolean = true
  override def dataType: DataType = DoubleType

  override def prettyName: String = name
  override def toString: String = s"$name()"

  // The input row is ignored: the result is always the constant.
  override def eval(input: InternalRow): Any = c
}
/**
 * Base class for single-argument math functions over `Double`.
 *
 * The input is implicitly cast to `DoubleType` (`ImplicitCastInputTypes`), and a null input
 * yields null (`NullIntolerant`). Interpreted evaluation applies `f`. Generated code calls the
 * `java.lang.Math` method named by `funcName`.
 *
 * @param f    the function applied in interpreted mode
 * @param name the short function name; its lower-cased form is the default `java.lang.Math`
 *             method used by code generation
 */
abstract class UnaryMathExpression(val f: Double => Double, name: String)
  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable {

  override def dataType: DataType = DoubleType
  override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType)
  override def nullable: Boolean = true

  override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse(name)
  override def toString: String = s"$prettyName($child)"

  /** Name of the `java.lang.Math` method invoked by the generated code. */
  def funcName: String = name.toLowerCase(Locale.ROOT)

  protected override def nullSafeEval(input: Any): Any = {
    val x = input.asInstanceOf[Double]
    f(x)
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    defineCodeGen(ctx, ev, c => s"java.lang.Math.${funcName}($c)")
}
/**
 * Base class for logarithm-like functions. Following Hive, an input less than or equal to
 * `yAsymptote` evaluates to null instead of NaN or -Infinity. Every other input, including
 * NaN, is passed to `f`.
 *
 * Generated code calls the `java.lang.StrictMath` method named by `funcName`.
 *
 * @param f    the function applied in interpreted mode to inputs above the asymptote
 * @param name the short function name
 */
abstract class UnaryLogExpression(f: Double => Double, name: String)
  extends UnaryMathExpression(f, name) {

  override def nullable: Boolean = true

  // Inputs at or below this bound produce null (Hive semantics). Subclasses may override it.
  protected val yAsymptote: Double = 0.0

  protected override def nullSafeEval(input: Any): Any = input.asInstanceOf[Double] match {
    case d if d <= yAsymptote => null
    case d => f(d)
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(ctx, ev, c =>
s"""
if ($c <= $yAsymptote) {
${ev.isNull} = true;
} else {
${ev.value} = java.lang.StrictMath.${funcName}($c);
}
"""
    )
  }
}
/**
 * Base class for two-argument math functions over `Double`s that return a `Double`.
 *
 * Both inputs are implicitly cast to `DoubleType`, and a null in either input yields null.
 * Generated code calls the `java.lang.Math` method whose name is the lower-cased `name`.
 *
 * @param f    the function applied in interpreted mode
 * @param name the short function name
 */
abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
  extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable {

  override def dataType: DataType = DoubleType
  override def inputTypes: Seq[DataType] = Seq(DoubleType, DoubleType)

  override def prettyName: String = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse(name)
  override def toString: String = s"$prettyName($left, $right)"

  protected override def nullSafeEval(input1: Any, input2: Any): Any = {
    val (x, y) = (input1.asInstanceOf[Double], input2.asInstanceOf[Double])
    f(x, y)
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val method = name.toLowerCase(Locale.ROOT)
    defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.$method($c1, $c2)")
  }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
// Leaf math functions
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
/**
 * `e()`: Euler's number. Because `LeafMathExpression` has no code generation, this is expected
 * to be constant-folded by the optimizer.
 */
@ExpressionDescription(
usage = "_FUNC_() - Returns Euler's number, e.",
examples = """
Examples:
> SELECT _FUNC_();
2.718281828459045
""",
since = "1.5.0",
group = "math_funcs")
case class EulerNumber() extends LeafMathExpression(c = math.E, name = "E")
/**
 * `pi()`: the constant pi. Because `LeafMathExpression` has no code generation, this is
 * expected to be constant-folded by the optimizer.
 */
@ExpressionDescription(
usage = "_FUNC_() - Returns pi.",
examples = """
Examples:
> SELECT _FUNC_();
3.141592653589793
""",
since = "1.5.0",
group = "math_funcs")
case class Pi() extends LeafMathExpression(c = math.Pi, name = "PI")
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
// Unary math functions
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
/** `acos(expr)`: inverse cosine; generated code calls `java.lang.Math.acos`. */
@ExpressionDescription(
usage = """
_FUNC_(expr) - Returns the inverse cosine (a.k.a. arc cosine) of `expr`, as if computed by
`java.lang.Math._FUNC_`.
""",
examples = """
Examples:
> SELECT _FUNC_(1);
0.0
> SELECT _FUNC_(2);
NaN
""",
since = "1.4.0",
group = "math_funcs")
case class Acos(child: Expression)
  extends UnaryMathExpression(f = math.acos, name = "ACOS") {
  override protected def withNewChildInternal(newChild: Expression): Acos =
    Acos(newChild)
}
/**
 * `asin(expr)`: inverse sine; generated code calls `java.lang.Math.asin`.
 * The user-facing usage text previously read "(a.k.a. arc sine) the arc sin of `expr`",
 * which was garbled.
 */
@ExpressionDescription(
usage = """
_FUNC_(expr) - Returns the inverse sine (a.k.a. arc sine) of `expr`,
as if computed by `java.lang.Math._FUNC_`.
""",
examples = """
Examples:
> SELECT _FUNC_(0);
0.0
> SELECT _FUNC_(2);
NaN
""",
since = "1.4.0",
group = "math_funcs")
case class Asin(child: Expression) extends UnaryMathExpression(math.asin, "ASIN") {
  override protected def withNewChildInternal(newChild: Expression): Asin = copy(child = newChild)
}
/**
 * `atan(expr)`: inverse tangent; generated code calls `java.lang.Math.atan`.
 * The usage sentence now ends with a period, matching the sibling trig functions.
 */
@ExpressionDescription(
usage = """
_FUNC_(expr) - Returns the inverse tangent (a.k.a. arc tangent) of `expr`, as if computed by
`java.lang.Math._FUNC_`.
""",
examples = """
Examples:
> SELECT _FUNC_(0);
0.0
""",
since = "1.4.0",
group = "math_funcs")
case class Atan(child: Expression) extends UnaryMathExpression(math.atan, "ATAN") {
  override protected def withNewChildInternal(newChild: Expression): Atan = copy(child = newChild)
}
/** `cbrt(expr)`: cube root; generated code calls `java.lang.Math.cbrt`. */
@ExpressionDescription(
usage = "_FUNC_(expr) - Returns the cube root of `expr`.",
examples = """
Examples:
> SELECT _FUNC_(27.0);
3.0
""",
since = "1.4.0",
group = "math_funcs")
case class Cbrt(child: Expression)
  extends UnaryMathExpression(f = math.cbrt, name = "CBRT") {
  override protected def withNewChildInternal(newChild: Expression): Cbrt =
    Cbrt(newChild)
}
/**
 * Single-argument `ceil`.
 *
 *  - LONG input is returned unchanged.
 *  - DOUBLE input is rounded up and converted to LONG.
 *  - DECIMAL input is rounded up to scale 0. A scale-0 decimal keeps its type; otherwise the
 *    result type is `DECIMAL(precision - scale + 1, 0)`, bounded by `DecimalType.bounded`.
 */
case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL") {

  override def inputTypes: Seq[AbstractDataType] =
    Seq(TypeCollection(DoubleType, DecimalType, LongType))

  override def dataType: DataType = child.dataType match {
    case dt @ DecimalType.Fixed(_, 0) => dt
    case DecimalType.Fixed(precision, scale) => DecimalType.bounded(precision - scale + 1, 0)
    case _ => LongType
  }

  protected override def nullSafeEval(input: Any): Any = child.dataType match {
    case DoubleType => f(input.asInstanceOf[Double]).toLong
    case LongType => input.asInstanceOf[Long]
    case DecimalType.Fixed(_, _) => input.asInstanceOf[Decimal].ceil
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    // Pick the Java expression template for the input type, then emit it once.
    val template: String => String = child.dataType match {
      case LongType | DecimalType.Fixed(_, 0) => (c: String) => s"$c"
      case DecimalType.Fixed(_, _) => (c: String) => s"$c.ceil()"
      case _ => (c: String) => s"(long)(java.lang.Math.${funcName}($c))"
    }
    defineCodeGen(ctx, ev, template)
  }

  override protected def withNewChildInternal(newChild: Expression): Ceil = copy(child = newChild)
}
/**
 * Shared function builder for `ceil` and `floor`. Both accept one argument (`expr`) or two
 * (`expr`, `scale`). In the two-argument form, `scale` must be a foldable, non-null INT.
 */
trait CeilFloorExpressionBuilderBase extends ExpressionBuilder {
  /** Builds the expression for the one-argument form. */
  protected def buildWithOneParam(param: Expression): Expression

  /** Builds the expression for the two-argument form; `param2` is an already-validated scale. */
  protected def buildWithTwoParams(param1: Expression, param2: Expression): Expression

  /**
   * @throws org.apache.spark.sql.AnalysisException if `scale` is not a non-null foldable INT,
   *         or if the number of arguments is neither 1 nor 2
   */
  override def build(funcName: String, expressions: Seq[Expression]): Expression = {
    expressions match {
      case Seq(param) =>
        buildWithOneParam(param)
      case Seq(param, scale) =>
        // Short-circuiting guarantees eval() only runs on a foldable INT expression.
        val isValidScale =
          scale.foldable && scale.dataType == IntegerType && scale.eval() != null
        if (!isValidScale) {
          throw QueryCompilationErrors.requireLiteralParameter(funcName, "scale", "int")
        }
        buildWithTwoParams(param, scale)
      case _ =>
        // Both the one- and two-argument forms are valid, so report both counts.
        throw QueryCompilationErrors.invalidFunctionArgumentNumberError(
          Seq(1, 2), funcName, expressions.length)
    }
  }
}
// Builder registered for SQL `ceil`: one argument maps to `Ceil`, two arguments to `RoundCeil`.
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(expr[, scale]) - Returns the smallest number after rounding up that is not smaller than `expr`. An optional `scale` parameter can be specified to control the rounding behavior.",
examples = """
Examples:
> SELECT _FUNC_(-0.1);
0
> SELECT _FUNC_(5);
5
> SELECT _FUNC_(3.1411, 3);
3.142
> SELECT _FUNC_(3.1411, -3);
1000
""",
since = "3.3.0",
group = "math_funcs")
// scalastyle:on line.size.limit
object CeilExpressionBuilder extends CeilFloorExpressionBuilderBase {
  override protected def buildWithOneParam(param: Expression): Expression =
    new Ceil(param)

  override protected def buildWithTwoParams(param1: Expression, param2: Expression): Expression =
    RoundCeil(child = param1, scale = param2)
}
/**
 * Two-argument `ceil(expr, scale)`: rounds a DECIMAL `child` to `scale` digits using
 * `BigDecimal.RoundingMode.CEILING` through the shared `RoundBase` logic. Displayed as `ceil`.
 */
case class RoundCeil(child: Expression, scale: Expression)
  extends RoundBase(child, scale, BigDecimal.RoundingMode.CEILING, "ROUND_CEILING") {

  override def nodeName: String = "ceil"

  override def inputTypes: Seq[AbstractDataType] = Seq(DecimalType, IntegerType)

  override protected def withNewChildrenInternal(
      newLeft: Expression, newRight: Expression): RoundCeil =
    copy(scale = newRight, child = newLeft)
}
/** `cos(expr)`: cosine of an angle in radians; generated code calls `java.lang.Math.cos`. */
@ExpressionDescription(
usage = """
_FUNC_(expr) - Returns the cosine of `expr`, as if computed by
`java.lang.Math._FUNC_`.
""",
arguments = """
Arguments:
* expr - angle in radians
""",
examples = """
Examples:
> SELECT _FUNC_(0);
1.0
""",
since = "1.4.0",
group = "math_funcs")
case class Cos(child: Expression)
  extends UnaryMathExpression(f = math.cos, name = "COS") {
  override protected def withNewChildInternal(newChild: Expression): Cos =
    Cos(newChild)
}
/**
 * `sec(expr)`: secant, computed as `1 / cos(expr)` in both interpreted and generated code.
 */
@ExpressionDescription(
usage = """
_FUNC_(expr) - Returns the secant of `expr`, as if computed by `1/java.lang.Math.cos`.
""",
arguments = """
Arguments:
* expr - angle in radians
""",
examples = """
Examples:
> SELECT _FUNC_(0);
1.0
""",
since = "3.3.0",
group = "math_funcs")
case class Sec(child: Expression)
  extends UnaryMathExpression((x: Double) => 1 / math.cos(x), "SEC") {
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    // Like the other defineCodeGen callers in this file, pass only the right-hand-side
    // expression. defineCodeGen emits the assignment to ev.value itself; embedding
    // "${ev.value} = ...;" here produced a redundant "v = v = ...;;".
    defineCodeGen(ctx, ev, c => s"1 / java.lang.Math.cos($c)")
  }
  override protected def withNewChildInternal(newChild: Expression): Sec = copy(child = newChild)
}
/** `cosh(expr)`: hyperbolic cosine; generated code calls `java.lang.Math.cosh`. */
@ExpressionDescription(
usage = """
_FUNC_(expr) - Returns the hyperbolic cosine of `expr`, as if computed by
`java.lang.Math._FUNC_`.
""",
arguments = """
Arguments:
* expr - hyperbolic angle
""",
examples = """
Examples:
> SELECT _FUNC_(0);
1.0
""",
since = "1.4.0",
group = "math_funcs")
case class Cosh(child: Expression)
  extends UnaryMathExpression(f = math.cosh, name = "COSH") {
  override protected def withNewChildInternal(newChild: Expression): Cosh =
    Cosh(newChild)
}
/**
 * `acosh(expr)`: inverse hyperbolic cosine, computed as `ln(x + sqrt(x * x - 1))`.
 * Interpreted and generated code evaluate the same formula.
 */
@ExpressionDescription(
usage = """
_FUNC_(expr) - Returns inverse hyperbolic cosine of `expr`.
""",
examples = """
Examples:
> SELECT _FUNC_(1);
0.0
> SELECT _FUNC_(0);
NaN
""",
since = "3.0.0",
group = "math_funcs")
case class Acosh(child: Expression)
  extends UnaryMathExpression(
    f = (x: Double) => StrictMath.log(x + math.sqrt(x * x - 1.0)),
    name = "ACOSH") {

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    defineCodeGen(ctx, ev, c => s"java.lang.StrictMath.log($c + java.lang.Math.sqrt($c * $c - 1.0))")

  override protected def withNewChildInternal(newChild: Expression): Acosh =
    Acosh(newChild)
}
/**
 * `conv(num, from_base, to_base)`: converts the (trimmed) string `num` from one numeric base
 * to another using `NumberConverter.convert`. The result is null whenever the converter
 * returns null.
 *
 * @param numExpr      the number to convert, as a string
 * @param fromBaseExpr the base `num` is written in
 * @param toBaseExpr   the base to convert to
 */
@ExpressionDescription(
usage = "_FUNC_(num, from_base, to_base) - Convert `num` from `from_base` to `to_base`.",
examples = """
Examples:
> SELECT _FUNC_('100', 2, 10);
4
> SELECT _FUNC_(-10, 16, -10);
-16
""",
since = "1.5.0",
group = "math_funcs")
case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression)
  extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {

  override def first: Expression = numExpr
  override def second: Expression = fromBaseExpr
  override def third: Expression = toBaseExpr

  override def dataType: DataType = StringType
  override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType, IntegerType)
  // NumberConverter.convert can return null (the generated code checks for it).
  override def nullable: Boolean = true

  override def nullSafeEval(num: Any, fromBase: Any, toBase: Any): Any = {
    val digits = num.asInstanceOf[UTF8String].trim().getBytes
    NumberConverter.convert(digits, fromBase.asInstanceOf[Int], toBase.asInstanceOf[Int])
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val converterClass = NumberConverter.getClass.getName.stripSuffix("$")
    nullSafeCodeGen(ctx, ev, (num, from, to) =>
s"""
${ev.value} = $converterClass.convert($num.trim().getBytes(), $from, $to);
if (${ev.value} == null) {
${ev.isNull} = true;
}
"""
    )
  }

  override protected def withNewChildrenInternal(
      newFirst: Expression, newSecond: Expression, newThird: Expression): Expression =
    copy(toBaseExpr = newThird, fromBaseExpr = newSecond, numExpr = newFirst)
}
/** `exp(expr)`: e raised to `expr`, computed with `StrictMath.exp` in both eval paths. */
@ExpressionDescription(
usage = "_FUNC_(expr) - Returns e to the power of `expr`.",
examples = """
Examples:
> SELECT _FUNC_(0);
1.0
""",
since = "1.4.0",
group = "math_funcs")
case class Exp(child: Expression)
  extends UnaryMathExpression(f = StrictMath.exp, name = "EXP") {

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    defineCodeGen(ctx, ev, c => s"java.lang.StrictMath.exp($c)")

  override protected def withNewChildInternal(newChild: Expression): Exp =
    Exp(newChild)
}
/** `expm1(expr)`: `exp(expr) - 1`, computed with `StrictMath.expm1` in both eval paths. */
@ExpressionDescription(
usage = "_FUNC_(expr) - Returns exp(`expr`) - 1.",
examples = """
Examples:
> SELECT _FUNC_(0);
0.0
""",
since = "1.4.0",
group = "math_funcs")
case class Expm1(child: Expression)
  extends UnaryMathExpression(f = StrictMath.expm1, name = "EXPM1") {

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    defineCodeGen(ctx, ev, c => s"java.lang.StrictMath.expm1($c)")

  override protected def withNewChildInternal(newChild: Expression): Expm1 =
    Expm1(newChild)
}
/**
 * Single-argument `floor`.
 *
 *  - LONG input is returned unchanged.
 *  - DOUBLE input is rounded down and converted to LONG.
 *  - DECIMAL input is rounded down to scale 0. A scale-0 decimal keeps its type; otherwise the
 *    result type is `DECIMAL(precision - scale + 1, 0)`, bounded by `DecimalType.bounded`.
 */
case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLOOR") {

  override def inputTypes: Seq[AbstractDataType] =
    Seq(TypeCollection(DoubleType, DecimalType, LongType))

  override def dataType: DataType = child.dataType match {
    case dt @ DecimalType.Fixed(_, 0) => dt
    case DecimalType.Fixed(precision, scale) => DecimalType.bounded(precision - scale + 1, 0)
    case _ => LongType
  }

  protected override def nullSafeEval(input: Any): Any = child.dataType match {
    case DoubleType => f(input.asInstanceOf[Double]).toLong
    case LongType => input.asInstanceOf[Long]
    case DecimalType.Fixed(_, _) => input.asInstanceOf[Decimal].floor
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    // Pick the Java expression template for the input type, then emit it once.
    val template: String => String = child.dataType match {
      case LongType | DecimalType.Fixed(_, 0) => (c: String) => s"$c"
      case DecimalType.Fixed(_, _) => (c: String) => s"$c.floor()"
      case _ => (c: String) => s"(long)(java.lang.Math.${funcName}($c))"
    }
    defineCodeGen(ctx, ev, template)
  }

  override protected def withNewChildInternal(newChild: Expression): Floor =
    copy(child = newChild)
}
// Builder registered for SQL `floor`: one argument maps to `Floor`, two arguments to
// `RoundFloor`. The usage text previously started with a stray space before `_FUNC_`.
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(expr[, scale]) - Returns the largest number after rounding down that is not greater than `expr`. An optional `scale` parameter can be specified to control the rounding behavior.",
examples = """
Examples:
> SELECT _FUNC_(-0.1);
-1
> SELECT _FUNC_(5);
5
> SELECT _FUNC_(3.1411, 3);
3.141
> SELECT _FUNC_(3.1411, -3);
0
""",
since = "3.3.0",
group = "math_funcs")
// scalastyle:on line.size.limit
object FloorExpressionBuilder extends CeilFloorExpressionBuilderBase {
  override protected def buildWithOneParam(param: Expression): Expression = Floor(param)
  override protected def buildWithTwoParams(param1: Expression, param2: Expression): Expression =
    RoundFloor(param1, param2)
}
/**
 * Two-argument `floor(expr, scale)`: rounds a DECIMAL `child` to `scale` digits using
 * `BigDecimal.RoundingMode.FLOOR` through the shared `RoundBase` logic. Displayed as `floor`.
 */
case class RoundFloor(child: Expression, scale: Expression)
  extends RoundBase(child, scale, BigDecimal.RoundingMode.FLOOR, "ROUND_FLOOR") {

  override def nodeName: String = "floor"

  override def inputTypes: Seq[AbstractDataType] = Seq(DecimalType, IntegerType)

  override protected def withNewChildrenInternal(
      newLeft: Expression, newRight: Expression): RoundFloor =
    copy(scale = newRight, child = newLeft)
}
/** Table-driven factorials for 0..20, the inputs whose factorial fits in a `Long`. */
object Factorial {

  /**
   * Returns `n!` for `0 <= n <= 20` and `Long.MaxValue` for `n > 20`.
   * A negative `n` indexes the table out of bounds and throws
   * `ArrayIndexOutOfBoundsException`, so callers must reject it first.
   */
  def factorial(n: Int): Long =
    if (n >= factorials.length) Long.MaxValue else factorials(n)

  // factorials(k) == k! for k in 0..20 (21 entries), built as running products of 1..20.
  private val factorials: Array[Long] = (1 to 20).scanLeft(1L)(_ * _).toArray
}
/**
 * `factorial(expr)`: `expr!` as a LONG for `expr` in [0, 20]; any other value yields null.
 */
@ExpressionDescription(
usage = "_FUNC_(expr) - Returns the factorial of `expr`. `expr` is [0..20]. Otherwise, null.",
examples = """
Examples:
> SELECT _FUNC_(5);
120
""",
since = "1.5.0",
group = "math_funcs")
case class Factorial(child: Expression)
  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {

  override def inputTypes: Seq[DataType] = Seq(IntegerType)
  override def dataType: DataType = LongType
  // Out-of-range inputs produce null even when the child is non-null.
  override def nullable: Boolean = true

  protected override def nullSafeEval(input: Any): Any = {
    val n = input.asInstanceOf[Int]
    if (n < 0 || n > 20) null else Factorial.factorial(n)
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(ctx, ev, eval => {
s"""
if ($eval > 20 || $eval < 0) {
${ev.isNull} = true;
} else {
${ev.value} =
org.apache.spark.sql.catalyst.expressions.Factorial.factorial($eval);
}
"""
    })
  }

  override protected def withNewChildInternal(newChild: Expression): Factorial =
    Factorial(newChild)
}
/**
 * `ln(expr)`: natural logarithm via `StrictMath.log`; inputs `<= 0` evaluate to null.
 * Displayed as `ln` unless registered under an alias.
 */
@ExpressionDescription(
usage = "_FUNC_(expr) - Returns the natural logarithm (base e) of `expr`.",
examples = """
Examples:
> SELECT _FUNC_(1);
0.0
""",
since = "1.4.0",
group = "math_funcs")
case class Log(child: Expression)
  extends UnaryLogExpression(f = StrictMath.log, name = "LOG") {

  override def prettyName: String =
    getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("ln")

  override protected def withNewChildInternal(newChild: Expression): Log =
    Log(newChild)
}
/**
 * `log2(expr)`: base-2 logarithm, computed as `StrictMath.log(x) / StrictMath.log(2)`.
 * Inputs at or below the inherited asymptote (0) evaluate to null.
 */
@ExpressionDescription(
usage = "_FUNC_(expr) - Returns the logarithm of `expr` with base 2.",
examples = """
Examples:
> SELECT _FUNC_(2);
1.0
""",
since = "1.4.0",
group = "math_funcs")
case class Log2(child: Expression)
  extends UnaryLogExpression(
    f = (x: Double) => StrictMath.log(x) / StrictMath.log(2),
    name = "LOG2") {

  // Same asymptote guard as the parent, but the quotient is spelled out instead of the
  // parent's single `StrictMath.<funcName>` call.
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(ctx, ev, c =>
s"""
if ($c <= $yAsymptote) {
${ev.isNull} = true;
} else {
${ev.value} = java.lang.StrictMath.log($c) / java.lang.StrictMath.log(2);
}
"""
    )
  }

  override protected def withNewChildInternal(newChild: Expression): Log2 =
    Log2(newChild)
}
/** `log10(expr)`: base-10 logarithm via `StrictMath.log10`; inputs `<= 0` evaluate to null. */
@ExpressionDescription(
usage = "_FUNC_(expr) - Returns the logarithm of `expr` with base 10.",
examples = """
Examples:
> SELECT _FUNC_(10);
1.0
""",
since = "1.4.0",
group = "math_funcs")
case class Log10(child: Expression)
  extends UnaryLogExpression(f = StrictMath.log10, name = "LOG10") {
  override protected def withNewChildInternal(newChild: Expression): Log10 =
    Log10(newChild)
}
/** `log1p(expr)`: `ln(1 + expr)` via `StrictMath.log1p`; inputs `<= -1` evaluate to null. */
@ExpressionDescription(
usage = "_FUNC_(expr) - Returns log(1 + `expr`).",
examples = """
Examples:
> SELECT _FUNC_(0);
0.0
""",
since = "1.4.0",
group = "math_funcs")
case class Log1p(child: Expression)
  extends UnaryLogExpression(f = StrictMath.log1p, name = "LOG1P") {

  // log1p(x) is undefined at or below x = -1 rather than 0.
  protected override val yAsymptote: Double = -1.0

  override protected def withNewChildInternal(newChild: Expression): Log1p =
    Log1p(newChild)
}
/**
 * `rint(expr)`: the closest mathematical integer as a DOUBLE. Constructed with the name
 * "ROUND", but both the generated `java.lang.Math` call and the display name are `rint`.
 */
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(expr) - Returns the double value that is closest in value to the argument and is equal to a mathematical integer.",
examples = """
Examples:
> SELECT _FUNC_(12.3456);
12.0
""",
since = "1.4.0",
group = "math_funcs")
// scalastyle:on line.size.limit
case class Rint(child: Expression)
  extends UnaryMathExpression(f = math.rint, name = "ROUND") {

  override def funcName: String = "rint"

  override def prettyName: String =
    getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("rint")

  override protected def withNewChildInternal(newChild: Expression): Rint =
    Rint(newChild)
}
/**
 * `signum(expr)`: -1.0, 0.0 or 1.0 by the sign of `expr`. Besides DOUBLE, YEAR-MONTH and
 * DAY-TIME intervals are accepted; interpreted evaluation reads the input as a
 * `java.lang.Number`.
 */
@ExpressionDescription(
usage = "_FUNC_(expr) - Returns -1.0, 0.0 or 1.0 as `expr` is negative, 0 or positive.",
examples = """
Examples:
> SELECT _FUNC_(40);
1.0
> SELECT _FUNC_(INTERVAL -'100' YEAR);
-1.0
""",
since = "1.4.0",
group = "math_funcs")
case class Signum(child: Expression)
  extends UnaryMathExpression(math.signum, "SIGNUM") {

  override def inputTypes: Seq[AbstractDataType] =
    Seq(TypeCollection(DoubleType, YearMonthIntervalType, DayTimeIntervalType))

  protected override def nullSafeEval(input: Any): Any = {
    val asDouble = input.asInstanceOf[Number].doubleValue()
    f(asDouble)
  }

  override protected def withNewChildInternal(newChild: Expression): Signum =
    Signum(newChild)
}
/** `sin(expr)`: sine of an angle in radians; generated code calls `java.lang.Math.sin`. */
@ExpressionDescription(
usage = "_FUNC_(expr) - Returns the sine of `expr`, as if computed by `java.lang.Math._FUNC_`.",
arguments = """
Arguments:
* expr - angle in radians
""",
examples = """
Examples:
> SELECT _FUNC_(0);
0.0
""",
since = "1.4.0",
group = "math_funcs")
case class Sin(child: Expression)
  extends UnaryMathExpression(f = math.sin, name = "SIN") {
  override protected def withNewChildInternal(newChild: Expression): Sin =
    Sin(newChild)
}
/**
 * `csc(expr)`: cosecant, computed as `1 / sin(expr)` in both interpreted and generated code.
 */
@ExpressionDescription(
usage = """
_FUNC_(expr) - Returns the cosecant of `expr`, as if computed by `1/java.lang.Math.sin`.
""",
arguments = """
Arguments:
* expr - angle in radians
""",
examples = """
Examples:
> SELECT _FUNC_(1);
1.1883951057781212
""",
since = "3.3.0",
group = "math_funcs")
case class Csc(child: Expression)
  extends UnaryMathExpression((x: Double) => 1 / math.sin(x), "CSC") {
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    // Like the other defineCodeGen callers in this file, pass only the right-hand-side
    // expression. defineCodeGen emits the assignment to ev.value itself; embedding
    // "${ev.value} = ...;" here produced a redundant "v = v = ...;;".
    defineCodeGen(ctx, ev, c => s"1 / java.lang.Math.sin($c)")
  }
  override protected def withNewChildInternal(newChild: Expression): Csc = copy(child = newChild)
}
/** `sinh(expr)`: hyperbolic sine; generated code calls `java.lang.Math.sinh`. */
@ExpressionDescription(
usage = """
_FUNC_(expr) - Returns hyperbolic sine of `expr`, as if computed by `java.lang.Math._FUNC_`.
""",
arguments = """
Arguments:
* expr - hyperbolic angle
""",
examples = """
Examples:
> SELECT _FUNC_(0);
0.0
""",
since = "1.4.0",
group = "math_funcs")
case class Sinh(child: Expression)
  extends UnaryMathExpression(f = math.sinh, name = "SINH") {
  override protected def withNewChildInternal(newChild: Expression): Sinh =
    Sinh(newChild)
}
/**
 * `asinh(expr)`: inverse hyperbolic sine, computed as `ln(x + sqrt(x * x + 1))`.
 * Negative infinity is special-cased to return negative infinity.
 */
@ExpressionDescription(
usage = """
_FUNC_(expr) - Returns inverse hyperbolic sine of `expr`.
""",
examples = """
Examples:
> SELECT _FUNC_(0);
0.0
""",
since = "3.0.0",
group = "math_funcs")
case class Asinh(child: Expression)
  extends UnaryMathExpression(
    f = (x: Double) =>
      if (x == Double.NegativeInfinity) Double.NegativeInfinity
      else StrictMath.log(x + math.sqrt(x * x + 1.0)),
    name = "ASINH") {

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val guard = (c: String) => s"$c == Double.NEGATIVE_INFINITY ? Double.NEGATIVE_INFINITY : "
    val formula = (c: String) => s"java.lang.StrictMath.log($c + java.lang.Math.sqrt($c * $c + 1.0))"
    defineCodeGen(ctx, ev, c => guard(c) + formula(c))
  }

  override protected def withNewChildInternal(newChild: Expression): Asinh =
    Asinh(newChild)
}
/** `sqrt(expr)`: square root; generated code calls `java.lang.Math.sqrt`. */
@ExpressionDescription(
usage = "_FUNC_(expr) - Returns the square root of `expr`.",
examples = """
Examples:
> SELECT _FUNC_(4);
2.0
""",
since = "1.1.1",
group = "math_funcs")
case class Sqrt(child: Expression)
  extends UnaryMathExpression(f = math.sqrt, name = "SQRT") {
  override protected def withNewChildInternal(newChild: Expression): Sqrt =
    Sqrt(newChild)
}
/** `tan(expr)`: tangent of an angle in radians; generated code calls `java.lang.Math.tan`. */
@ExpressionDescription(
usage = """
_FUNC_(expr) - Returns the tangent of `expr`, as if computed by `java.lang.Math._FUNC_`.
""",
arguments = """
Arguments:
* expr - angle in radians
""",
examples = """
Examples:
> SELECT _FUNC_(0);
0.0
""",
since = "1.4.0",
group = "math_funcs")
case class Tan(child: Expression)
  extends UnaryMathExpression(f = math.tan, name = "TAN") {
  override protected def withNewChildInternal(newChild: Expression): Tan =
    Tan(newChild)
}
/**
 * `cot(expr)`: cotangent, computed as `1 / tan(expr)` in both interpreted and generated code.
 */
@ExpressionDescription(
usage = """
_FUNC_(expr) - Returns the cotangent of `expr`, as if computed by `1/java.lang.Math.tan`.
""",
arguments = """
Arguments:
* expr - angle in radians
""",
examples = """
Examples:
> SELECT _FUNC_(1);
0.6420926159343306
""",
since = "2.3.0",
group = "math_funcs")
case class Cot(child: Expression)
  extends UnaryMathExpression((x: Double) => 1 / math.tan(x), "COT") {
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    // Like the other defineCodeGen callers in this file, pass only the right-hand-side
    // expression. defineCodeGen emits the assignment to ev.value itself; embedding
    // "${ev.value} = ...;" here produced a redundant "v = v = ...;;".
    defineCodeGen(ctx, ev, c => s"1 / java.lang.Math.tan($c)")
  }
  override protected def withNewChildInternal(newChild: Expression): Cot = copy(child = newChild)
}
/** `tanh(expr)`: hyperbolic tangent; generated code calls `java.lang.Math.tanh`. */
@ExpressionDescription(
usage = """
_FUNC_(expr) - Returns the hyperbolic tangent of `expr`, as if computed by
`java.lang.Math._FUNC_`.
""",
arguments = """
Arguments:
* expr - hyperbolic angle
""",
examples = """
Examples:
> SELECT _FUNC_(0);
0.0
""",
since = "1.4.0",
group = "math_funcs")
case class Tanh(child: Expression)
  extends UnaryMathExpression(f = math.tanh, name = "TANH") {
  override protected def withNewChildInternal(newChild: Expression): Tanh =
    Tanh(newChild)
}
/**
 * `atanh(expr)`: inverse hyperbolic tangent.
 *
 * SPARK-28519: evaluated as `0.5 * (log1p(x) - log1p(-x))`, a more accurate form of
 * `1/2 * ln((1 + x) / (1 - x))`. Interpreted and generated code use the same formula.
 */
@ExpressionDescription(
usage = """
_FUNC_(expr) - Returns inverse hyperbolic tangent of `expr`.
""",
examples = """
Examples:
> SELECT _FUNC_(0);
0.0
> SELECT _FUNC_(2);
NaN
""",
since = "3.0.0",
group = "math_funcs")
case class Atanh(child: Expression)
  extends UnaryMathExpression(
    f = (x: Double) => 0.5 * (StrictMath.log1p(x) - StrictMath.log1p(-x)),
    name = "ATANH") {

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    defineCodeGen(ctx, ev,
      c => s"0.5 * (java.lang.StrictMath.log1p($c) - java.lang.StrictMath.log1p(- $c))")

  override protected def withNewChildInternal(newChild: Expression): Atanh =
    Atanh(newChild)
}
/** `degrees(expr)`: radians to degrees; generated code calls `java.lang.Math.toDegrees`. */
@ExpressionDescription(
usage = "_FUNC_(expr) - Converts radians to degrees.",
arguments = """
Arguments:
* expr - angle in radians
""",
examples = """
Examples:
> SELECT _FUNC_(3.141592653589793);
180.0
""",
since = "1.4.0",
group = "math_funcs")
case class ToDegrees(child: Expression)
  extends UnaryMathExpression(math.toDegrees, "DEGREES") {

  // The SQL name is DEGREES, but the Java method is camel-cased.
  override def funcName: String = "toDegrees"

  override protected def withNewChildInternal(newChild: Expression): ToDegrees =
    ToDegrees(newChild)
}
/** `radians(expr)`: degrees to radians; generated code calls `java.lang.Math.toRadians`. */
@ExpressionDescription(
usage = "_FUNC_(expr) - Converts degrees to radians.",
arguments = """
Arguments:
* expr - angle in degrees
""",
examples = """
Examples:
> SELECT _FUNC_(180);
3.141592653589793
""",
since = "1.4.0",
group = "math_funcs")
case class ToRadians(child: Expression)
  extends UnaryMathExpression(math.toRadians, "RADIANS") {

  // The SQL name is RADIANS, but the Java method is camel-cased.
  override def funcName: String = "toRadians"

  override protected def withNewChildInternal(newChild: Expression): ToRadians =
    ToRadians(newChild)
}
/**
 * `bin(expr)`: binary string of a LONG, via `java.lang.Long.toBinaryString`. Negative values
 * therefore render as their 64-bit two's complement.
 */
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(expr) - Returns the string representation of the long value `expr` represented in binary.",
examples = """
Examples:
> SELECT _FUNC_(13);
1101
> SELECT _FUNC_(-13);
1111111111111111111111111111111111111111111111111111111111110011
> SELECT _FUNC_(13.3);
1101
""",
since = "1.5.0",
group = "math_funcs")
// scalastyle:on line.size.limit
case class Bin(child: Expression)
  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant with Serializable {

  override def dataType: DataType = StringType
  override def inputTypes: Seq[DataType] = Seq(LongType)

  protected override def nullSafeEval(input: Any): Any = {
    val bits = jl.Long.toBinaryString(input.asInstanceOf[Long])
    UTF8String.fromString(bits)
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    defineCodeGen(ctx, ev, c => s"UTF8String.fromString(java.lang.Long.toBinaryString($c))")

  override protected def withNewChildInternal(newChild: Expression): Bin =
    Bin(newChild)
}
/** Hexadecimal encoding and decoding helpers shared by `Hex`, `Unhex` and generated code. */
object Hex {
  /** ASCII codes of the upper-case hex digits, indexed by nibble value 0..15. */
  val hexDigits = Array[Char](
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F'
  ).map(_.toByte)

  // lookup table to translate '0' -> 0 ... 'F'/'f' -> 15; every other ASCII byte maps to -1
  val unhexDigits = {
    val array = Array.fill[Byte](128)(-1)
    (0 to 9).foreach(i => array('0' + i) = i.toByte)
    (0 to 5).foreach(i => array('A' + i) = (i + 10).toByte)
    (0 to 5).foreach(i => array('a' + i) = (i + 10).toByte)
    array
  }

  /** Encodes every byte as two upper-case hex digits (high nibble first). */
  def hex(bytes: Array[Byte]): UTF8String = {
    val length = bytes.length
    val value = new Array[Byte](length * 2)
    var i = 0
    while (i < length) {
      value(i * 2) = Hex.hexDigits((bytes(i) & 0xF0) >> 4)
      value(i * 2 + 1) = Hex.hexDigits(bytes(i) & 0x0F)
      i += 1
    }
    UTF8String.fromBytes(value)
  }

  /**
   * Encodes `num` in upper-case hex without leading zeros. The unsigned shift makes negative
   * numbers render as their 64-bit two's complement (16 digits).
   */
  def hex(num: Long): UTF8String = {
    // Extract the hex digits of num into value[] from right to left
    val value = new Array[Byte](16)
    var numBuf = num
    var len = 0
    do {
      len += 1
      value(value.length - len) = Hex.hexDigits((numBuf & 0xF).toInt)
      numBuf >>>= 4
    } while (numBuf != 0)
    UTF8String.fromBytes(java.util.Arrays.copyOfRange(value, value.length - len, value.length))
  }

  /**
   * Decodes hex digits (either case) into bytes. An odd number of digits is treated as if
   * left-padded with '0', so "ABC" decodes to 0x0A, 0xBC.
   *
   * @return the decoded bytes, or null if any input byte is not an ASCII hex digit
   */
  def unhex(bytes: Array[Byte]): Array[Byte] = {
    val out = new Array[Byte]((bytes.length + 1) >> 1)
    var i = 0
    if ((bytes.length & 0x01) != 0) {
      // padding with '0': the lone leading digit becomes out(0)
      if (bytes(0) < 0) {
        return null
      }
      val v = Hex.unhexDigits(bytes(0))
      if (v == -1) {
        return null
      }
      out(0) = v
      i += 1
    }
    // two characters form the hex value.
    while (i < bytes.length) {
      if (bytes(i) < 0 || bytes(i + 1) < 0) {
        return null
      }
      val first = Hex.unhexDigits(bytes(i))
      val second = Hex.unhexDigits(bytes(i + 1))
      if (first == -1 || second == -1) {
        return null
      }
      // For even-length input i is even and (i + 1) / 2 == i / 2. For odd-length input i is
      // odd and out(0) already holds the padded digit, so (i + 1) / 2 is the next free slot.
      // Using i / 2 here overwrote out(0) and left the last slot zero.
      out((i + 1) >> 1) = (((first << 4) | second) & 0xFF).toByte
      i += 2
    }
    out
  }
}
/**
 * `hex(expr)`: hexadecimal text for a LONG, BINARY or STRING input.
 *
 * A LONG is rendered as a number (negative values as two's complement). BINARY and STRING
 * inputs have each byte encoded as two hex digits; for STRING these are the UTF-8 bytes
 * returned by `UTF8String.getBytes`.
 */
@ExpressionDescription(
usage = "_FUNC_(expr) - Converts `expr` to hexadecimal.",
examples = """
Examples:
> SELECT _FUNC_(17);
11
> SELECT _FUNC_('Spark SQL');
537061726B2053514C
""",
since = "1.5.0",
group = "math_funcs")
case class Hex(child: Expression)
  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {

  override def dataType: DataType = StringType
  override def inputTypes: Seq[AbstractDataType] =
    Seq(TypeCollection(LongType, BinaryType, StringType))

  protected override def nullSafeEval(num: Any): Any = child.dataType match {
    case StringType => Hex.hex(num.asInstanceOf[UTF8String].getBytes)
    case BinaryType => Hex.hex(num.asInstanceOf[Array[Byte]])
    case LongType => Hex.hex(num.asInstanceOf[Long])
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val hexObject = Hex.getClass.getName.stripSuffix("$")
    nullSafeCodeGen(ctx, ev, c => {
      // Strings are encoded through their bytes; LONG and BINARY go straight to the overload.
      val argument = child.dataType match {
        case StringType => s"$c.getBytes()"
        case _ => c
      }
      s"${ev.value} = $hexObject.hex($argument);"
    })
  }

  override protected def withNewChildInternal(newChild: Expression): Hex =
    copy(child = newChild)
}
/**
 * Performs the inverse operation of HEX: decodes a STRING of hex digits into BINARY.
 * Input that is not valid hex yields NULL, or raises an error when `failOnError` is set.
 *
 * @param child the STRING of hex digits to decode
 * @param failOnError throw on invalid input instead of returning NULL
 */
@ExpressionDescription(
  usage = "_FUNC_(expr) - Converts hexadecimal `expr` to binary.",
  examples = """
    Examples:
      > SELECT decode(_FUNC_('537061726B2053514C'), 'UTF-8');
       Spark SQL
  """,
  since = "1.5.0",
  group = "math_funcs")
case class Unhex(child: Expression, failOnError: Boolean = false)
  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {

  def this(expr: Expression) = this(expr, false)

  override def inputTypes: Seq[AbstractDataType] = Seq(StringType)

  // Invalid hex produces NULL even when the child is non-null.
  override def nullable: Boolean = true
  override def dataType: DataType = BinaryType

  protected override def nullSafeEval(num: Any): Any = {
    val result = Hex.unhex(num.asInstanceOf[UTF8String].getBytes)
    if (failOnError && result == null) {
      // The failOnError is set only from `ToBinary` function - hence we might safely set `hint`
      // parameter to `try_to_binary`.
      throw QueryExecutionErrors.invalidInputInConversionError(
        BinaryType,
        num.asInstanceOf[UTF8String],
        UTF8String.fromString("HEX"),
        "try_to_binary")
    }
    result
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(ctx, ev, c => {
      val hex = Hex.getClass.getName.stripSuffix("$")
      val maybeFailOnErrorCode = if (failOnError) {
        val binaryType = ctx.addReferenceObj("to", BinaryType, BinaryType.getClass.getName)
        // The format argument must be a Java expression, and it must name the same "HEX" format
        // as the interpreted path above. Interpolating a UTF8String value here would emit a bare
        // identifier into the generated source.
        s"""
           |if (${ev.value} == null) {
           |  throw QueryExecutionErrors.invalidInputInConversionError(
           |    $binaryType,
           |    $c,
           |    UTF8String.fromString("HEX"),
           |    "try_to_binary");
           |}
           |""".stripMargin
      } else {
        s"${ev.isNull} = ${ev.value} == null;"
      }
      s"""
        ${ev.value} = $hex.unhex($c.getBytes());
        $maybeFailOnErrorCode
      """
    })
  }

  override protected def withNewChildInternal(newChild: Expression): Unhex =
    copy(child = newChild)
}
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
// Binary math functions
////////////////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////////////////
@ExpressionDescription(
  usage = """
    _FUNC_(exprY, exprX) - Returns the angle in radians between the positive x-axis of a plane
      and the point given by the coordinates (`exprX`, `exprY`), as if computed by
      `java.lang.Math._FUNC_`.
  """,
  arguments = """
    Arguments:
      * exprY - coordinate on y-axis
      * exprX - coordinate on x-axis
  """,
  examples = """
    Examples:
      > SELECT _FUNC_(0, 0);
       0.0
  """,
  since = "1.4.0",
  group = "math_funcs")
case class Atan2(left: Expression, right: Expression)
  extends BinaryMathExpression(math.atan2, "ATAN2") {

  protected override def nullSafeEval(input1: Any, input2: Any): Any = {
    // Adding 0.0 normalizes -0.0 to 0.0 so the interpreted and generated paths return the same
    // angle for signed zeros.
    val y = input1.asInstanceOf[Double] + 0.0
    val x = input2.asInstanceOf[Double] + 0.0
    math.atan2(y, x)
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    defineCodeGen(ctx, ev, (y, x) => s"java.lang.Math.atan2($y + 0.0, $x + 0.0)")

  override protected def withNewChildrenInternal(
      newLeft: Expression, newRight: Expression): Expression =
    copy(left = newLeft, right = newRight)
}
/**
 * `left` raised to the power `right`. Both the interpreted path and generated code use
 * `StrictMath.pow`.
 */
@ExpressionDescription(
  usage = "_FUNC_(expr1, expr2) - Raises `expr1` to the power of `expr2`.",
  examples = """
    Examples:
      > SELECT _FUNC_(2, 3);
       8.0
  """,
  since = "1.4.0",
  group = "math_funcs")
case class Pow(left: Expression, right: Expression)
  extends BinaryMathExpression(StrictMath.pow, "POWER") {

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    defineCodeGen(ctx, ev, (base, exponent) => s"java.lang.StrictMath.pow($base, $exponent)")

  override protected def withNewChildrenInternal(
      newLeft: Expression, newRight: Expression): Expression =
    copy(left = newLeft, right = newRight)
}
/**
 * Bitwise left shift.
 *
 * The result has the type of `left` (INT or BIGINT). JVM shift semantics apply to the shift
 * distance (see JLS 15.19).
 *
 * @param left the base number to shift.
 * @param right number of bits to left shift.
 */
@ExpressionDescription(
  usage = "_FUNC_(base, expr) - Bitwise left shift.",
  examples = """
    Examples:
      > SELECT _FUNC_(2, 1);
       4
  """,
  since = "1.5.0",
  group = "bitwise_funcs")
case class ShiftLeft(left: Expression, right: Expression)
  extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {

  override def inputTypes: Seq[AbstractDataType] =
    Seq(TypeCollection(IntegerType, LongType), IntegerType)

  override def dataType: DataType = left.dataType

  protected override def nullSafeEval(input1: Any, input2: Any): Any = {
    input1 match {
      case l: jl.Long => l << input2.asInstanceOf[jl.Integer]
      case i: jl.Integer => i << input2.asInstanceOf[jl.Integer]
    }
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    defineCodeGen(ctx, ev, (left, right) => s"$left << $right")
  }

  override protected def withNewChildrenInternal(
      newLeft: Expression, newRight: Expression): ShiftLeft = copy(left = newLeft, right = newRight)
}
/**
 * Bitwise arithmetic (sign-propagating) right shift of an INT or BIGINT value. The result keeps
 * the type of `left`.
 *
 * @param left the base number to shift.
 * @param right number of bits to right shift.
 */
@ExpressionDescription(
  usage = "_FUNC_(base, expr) - Bitwise (signed) right shift.",
  examples = """
    Examples:
      > SELECT _FUNC_(4, 1);
       2
  """,
  since = "1.5.0",
  group = "bitwise_funcs")
case class ShiftRight(left: Expression, right: Expression)
  extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {

  override def inputTypes: Seq[AbstractDataType] =
    Seq(TypeCollection(IntegerType, LongType), IntegerType)

  override def dataType: DataType = left.dataType

  protected override def nullSafeEval(input1: Any, input2: Any): Any = input1 match {
    case i: jl.Integer => i >> input2.asInstanceOf[jl.Integer]
    case l: jl.Long => l >> input2.asInstanceOf[jl.Integer]
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    defineCodeGen(ctx, ev, (value, distance) => s"$value >> $distance")

  override protected def withNewChildrenInternal(
      newLeft: Expression, newRight: Expression): ShiftRight =
    copy(left = newLeft, right = newRight)
}
/**
 * Bitwise logical (zero-filling) right shift of an INT or BIGINT value. The result keeps the
 * type of `left`.
 *
 * @param left the base number.
 * @param right the number of bits to right shift.
 */
@ExpressionDescription(
  usage = "_FUNC_(base, expr) - Bitwise unsigned right shift.",
  examples = """
    Examples:
      > SELECT _FUNC_(4, 1);
       2
  """,
  since = "1.5.0",
  group = "bitwise_funcs")
case class ShiftRightUnsigned(left: Expression, right: Expression)
  extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant {

  override def inputTypes: Seq[AbstractDataType] =
    Seq(TypeCollection(IntegerType, LongType), IntegerType)

  override def dataType: DataType = left.dataType

  protected override def nullSafeEval(input1: Any, input2: Any): Any = input1 match {
    case i: jl.Integer => i >>> input2.asInstanceOf[jl.Integer]
    case l: jl.Long => l >>> input2.asInstanceOf[jl.Integer]
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    defineCodeGen(ctx, ev, (value, distance) => s"$value >>> $distance")

  override protected def withNewChildrenInternal(
      newLeft: Expression, newRight: Expression): ShiftRightUnsigned =
    copy(left = newLeft, right = newRight)
}
/**
 * Euclidean norm sqrt(`left`^2 + `right`^2), evaluated with `math.hypot`. This class has no
 * codegen override, so the implementation inherited from [[BinaryMathExpression]] is used.
 */
@ExpressionDescription(
  usage = "_FUNC_(expr1, expr2) - Returns sqrt(`expr1`**2 + `expr2`**2).",
  examples = """
    Examples:
      > SELECT _FUNC_(3, 4);
       5.0
  """,
  since = "1.4.0",
  group = "math_funcs")
case class Hypot(left: Expression, right: Expression)
  extends BinaryMathExpression(math.hypot, "HYPOT") {

  override protected def withNewChildrenInternal(
      newLeft: Expression,
      newRight: Expression): Hypot =
    copy(left = newLeft, right = newRight)
}
/**
 * Logarithm of `right` in base `left`.
 *
 * Returns null when the base or the argument is <= 0. NaN operands are not rejected; they pass
 * through to the division and produce NaN.
 *
 * @param left the logarithm base; the single-argument constructor uses e.
 * @param right the number to compute the logarithm of.
 */
@ExpressionDescription(
  usage = "_FUNC_(base, expr) - Returns the logarithm of `expr` with `base`.",
  examples = """
    Examples:
      > SELECT _FUNC_(10, 100);
       2.0
  """,
  since = "1.5.0",
  group = "math_funcs")
case class Logarithm(left: Expression, right: Expression)
  extends BinaryMathExpression((c1, c2) => StrictMath.log(c2) / StrictMath.log(c1), "LOG") {

  /** Natural logarithm: e is used as the base. */
  def this(child: Expression) = this(EulerNumber(), child)

  // Non-positive inputs yield null.
  override def nullable: Boolean = true

  protected override def nullSafeEval(input1: Any, input2: Any): Any = {
    val base = input1.asInstanceOf[Double]
    val x = input2.asInstanceOf[Double]
    // Unlike Hive, we support Log base in (0.0, 1.0].
    // The `<=` tests are deliberate: a NaN operand must reach the division rather than yield null.
    if (base <= 0.0 || x <= 0.0) {
      null
    } else {
      StrictMath.log(x) / StrictMath.log(base)
    }
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    // With an e base, generated code checks only the argument and skips dividing by log(e).
    val isNaturalLog = left match {
      case _: EulerNumber => true
      case _ => false
    }
    nullSafeCodeGen(ctx, ev, (base, x) => {
      val (invalidInput, logValue) = if (isNaturalLog) {
        (s"$x <= 0.0", s"java.lang.StrictMath.log($x)")
      } else {
        (s"$base <= 0.0 || $x <= 0.0",
          s"java.lang.StrictMath.log($x) / java.lang.StrictMath.log($base)")
      }
      s"""
        if ($invalidInput) {
          ${ev.isNull} = true;
        } else {
          ${ev.value} = $logValue;
        }
      """
    })
  }

  override protected def withNewChildrenInternal(
      newLeft: Expression, newRight: Expression): Logarithm =
    copy(left = newLeft, right = newRight)
}
/**
 * Rounds the `child`'s result to `scale` decimal places when `scale` >= 0, or rounds within the
 * integral part when `scale` < 0.
 *
 * A child of IntegralType rounds to itself when `scale` >= 0.
 * A child of FractionalType whose value is NaN or Infinite always rounds to itself.
 * The result is null when `scale` evaluates to null or the child is null. For DecimalType it is
 * also null when the rounded value cannot be represented.
 *
 * Round's dataType always equals the `child`'s dataType, except for DecimalType. There the scale
 * may shrink (to `min(s, scale)`, or 0 for a negative scale) and the integral part gains one
 * digit of headroom, capped at [[DecimalType.MAX_PRECISION]].
 *
 * @param child expr to be rounded; any [[NumericType]] is allowed as input
 * @param scale new scale to round to; must be foldable, i.e. a constant int at runtime
 * @param mode rounding mode (e.g. HALF_UP, HALF_EVEN)
 * @param modeStr rounding mode string name (e.g. "ROUND_HALF_UP", "ROUND_HALF_EVEN")
 */
abstract class RoundBase(child: Expression, scale: Expression,
    mode: BigDecimal.RoundingMode.Value, modeStr: String)
  extends BinaryExpression with Serializable with ImplicitCastInputTypes {

  override def left: Expression = child
  override def right: Expression = scale

  // round of Decimal would eval to null if it fails to `changePrecision`
  override def nullable: Boolean = true

  override def foldable: Boolean = child.foldable

  override lazy val dataType: DataType = child.dataType match {
    case DecimalType.Fixed(p, s) =>
      // After rounding we may need one more digit in the integral part,
      // e.g. `ceil(9.9, 0)` -> `10`, `ceil(99, -1)` -> `100`.
      val integralLeastNumDigits = p - s + 1
      if (_scale < 0) {
        // negative scale means we need to adjust `-scale` number of digits before the decimal
        // point, which means we need at least `-scale + 1` digits (after rounding).
        val newPrecision = math.max(integralLeastNumDigits, -_scale + 1)
        // We have to accept the risk of overflow as we can't exceed the max precision.
        DecimalType(math.min(newPrecision, DecimalType.MAX_PRECISION), 0)
      } else {
        val newScale = math.min(s, _scale)
        // We have to accept the risk of overflow as we can't exceed the max precision.
        val newPrecision =
          math.min(integralLeastNumDigits + newScale, DecimalType.MAX_PRECISION)
        DecimalType(newPrecision, newScale)
      }
    case t => t
  }

  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType)

  override def checkInputDataTypes(): TypeCheckResult = {
    super.checkInputDataTypes() match {
      case TypeCheckSuccess =>
        if (scale.foldable) {
          TypeCheckSuccess
        } else {
          TypeCheckFailure("Only foldable Expression is allowed for scale arguments")
        }
      case f => f
    }
  }

  // Avoid repeated evaluation since `scale` is a constant int,
  // avoid unnecessary `child` evaluation in both codegen and non-codegen eval
  // by checking if scaleV == null as well.
  private lazy val scaleV: Any = scale.eval(EmptyRow)
  protected lazy val _scale: Int = scaleV.asInstanceOf[Int]

  override def eval(input: InternalRow): Any = {
    if (scaleV == null) { // if scale is null, no need to eval its child at all
      null
    } else {
      val evalE = child.eval(input)
      if (evalE == null) {
        null
      } else {
        nullSafeEval(evalE)
      }
    }
  }

  // not overriding since _scale is a constant int at runtime
  def nullSafeEval(input1: Any): Any = {
    dataType match {
      case DecimalType.Fixed(p, s) =>
        val decimal = input1.asInstanceOf[Decimal]
        if (_scale >= 0) {
          // Overflow cannot happen, so no need to control nullOnOverflow
          decimal.toPrecision(decimal.precision, s, mode)
        } else {
          Decimal(decimal.toBigDecimal.setScale(_scale, mode), p, s)
        }
      case ByteType =>
        BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte
      case ShortType =>
        BigDecimal(input1.asInstanceOf[Short]).setScale(_scale, mode).toShort
      case IntegerType =>
        BigDecimal(input1.asInstanceOf[Int]).setScale(_scale, mode).toInt
      case LongType =>
        BigDecimal(input1.asInstanceOf[Long]).setScale(_scale, mode).toLong
      case FloatType =>
        val f = input1.asInstanceOf[Float]
        if (f.isNaN || f.isInfinite) {
          f
        } else {
          BigDecimal(f.toDouble).setScale(_scale, mode).toFloat
        }
      case DoubleType =>
        val d = input1.asInstanceOf[Double]
        if (d.isNaN || d.isInfinite) {
          d
        } else {
          BigDecimal(d).setScale(_scale, mode).toDouble
        }
    }
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val ce = child.genCode(ctx)

    val evaluationCode = dataType match {
      case DecimalType.Fixed(p, s) =>
        if (_scale >= 0) {
          s"""
            ${ev.value} = ${ce.value}.toPrecision(${ce.value}.precision(), $s,
            Decimal.$modeStr(), true, null);
            ${ev.isNull} = ${ev.value} == null;"""
        } else {
          s"""
            ${ev.value} = new Decimal().set(${ce.value}.toBigDecimal()
            .setScale(${_scale}, Decimal.$modeStr()), $p, $s);
            ${ev.isNull} = ${ev.value} == null;"""
        }
      case ByteType =>
        if (_scale < 0) {
          s"""
            ${ev.value} = new java.math.BigDecimal(${ce.value}).
              setScale(${_scale}, java.math.BigDecimal.${modeStr}).byteValue();"""
        } else {
          s"${ev.value} = ${ce.value};"
        }
      case ShortType =>
        if (_scale < 0) {
          s"""
            ${ev.value} = new java.math.BigDecimal(${ce.value}).
              setScale(${_scale}, java.math.BigDecimal.${modeStr}).shortValue();"""
        } else {
          s"${ev.value} = ${ce.value};"
        }
      case IntegerType =>
        if (_scale < 0) {
          s"""
            ${ev.value} = new java.math.BigDecimal(${ce.value}).
              setScale(${_scale}, java.math.BigDecimal.${modeStr}).intValue();"""
        } else {
          s"${ev.value} = ${ce.value};"
        }
      case LongType =>
        if (_scale < 0) {
          s"""
            ${ev.value} = new java.math.BigDecimal(${ce.value}).
              setScale(${_scale}, java.math.BigDecimal.${modeStr}).longValue();"""
        } else {
          s"${ev.value} = ${ce.value};"
        }
      case FloatType => // if child eval to NaN or Infinity, just return it.
        s"""
          if (Float.isNaN(${ce.value}) || Float.isInfinite(${ce.value})) {
            ${ev.value} = ${ce.value};
          } else {
            ${ev.value} = java.math.BigDecimal.valueOf(${ce.value}).
              setScale(${_scale}, java.math.BigDecimal.${modeStr}).floatValue();
          }"""
      case DoubleType => // if child eval to NaN or Infinity, just return it.
        s"""
          if (Double.isNaN(${ce.value}) || Double.isInfinite(${ce.value})) {
            ${ev.value} = ${ce.value};
          } else {
            ${ev.value} = java.math.BigDecimal.valueOf(${ce.value}).
              setScale(${_scale}, java.math.BigDecimal.${modeStr}).doubleValue();
          }"""
    }

    val javaType = CodeGenerator.javaType(dataType)
    if (scaleV == null) { // if scale is null, no need to eval its child at all
      ev.copy(code = code"""
        boolean ${ev.isNull} = true;
        $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""")
    } else {
      ev.copy(code = code"""
        ${ce.code}
        boolean ${ev.isNull} = ${ce.isNull};
        $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
        if (!${ev.isNull}) {
          $evaluationCode
        }""")
    }
  }
}
/**
 * Rounds an expression to `d` decimal places with HALF_UP rounding (ties away from zero):
 * round(2.5) == 3.0, round(3.5) == 4.0. The single-argument form rounds to 0 places.
 */
// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = "_FUNC_(expr, d) - Returns `expr` rounded to `d` decimal places using HALF_UP rounding mode.",
  examples = """
    Examples:
      > SELECT _FUNC_(2.5, 0);
       3
      > SELECT _FUNC_(25, -1);
       30
  """,
  since = "1.5.0",
  group = "math_funcs")
// scalastyle:on line.size.limit
case class Round(child: Expression, scale: Expression)
  extends RoundBase(child, scale, BigDecimal.RoundingMode.HALF_UP, "ROUND_HALF_UP") {

  /** Single-argument form: rounds to 0 decimal places. */
  def this(child: Expression) = this(child, Literal(0))

  override protected def withNewChildrenInternal(
      newLeft: Expression, newRight: Expression): Round =
    copy(child = newLeft, scale = newRight)
}
/**
 * Rounds an expression to `d` decimal places with HALF_EVEN rounding, also known as Gaussian
 * or bankers' rounding: bround(2.5) == 2.0, bround(3.5) == 4.0.
 * The single-argument form rounds to 0 places.
 */
// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = "_FUNC_(expr, d) - Returns `expr` rounded to `d` decimal places using HALF_EVEN rounding mode.",
  examples = """
    Examples:
      > SELECT _FUNC_(2.5, 0);
       2
      > SELECT _FUNC_(25, -1);
       20
  """,
  since = "2.0.0",
  group = "math_funcs")
// scalastyle:on line.size.limit
case class BRound(child: Expression, scale: Expression)
  extends RoundBase(child, scale, BigDecimal.RoundingMode.HALF_EVEN, "ROUND_HALF_EVEN") {

  /** Single-argument form: rounds to 0 decimal places. */
  def this(child: Expression) = this(child, Literal(0))

  override protected def withNewChildrenInternal(
      newLeft: Expression, newRight: Expression): BRound =
    copy(child = newLeft, scale = newRight)
}
object WidthBucket {

  /**
   * Interpreted-path entry point. Returns the 1-based equi-width bucket of `value`, or null when
   * [[isNull]] rejects the arguments.
   */
  def computeBucketNumber(value: Double, min: Double, max: Double, numBucket: Long): jl.Long = {
    val valid = !isNull(value, min, max, numBucket)
    if (valid) computeBucketNumberNotNull(value, min, max, numBucket) else null
  }

  /**
   * True when no bucket can be computed:
   *  - `numBucket` is not in (0, Long.MaxValue),
   *  - `value` is NaN,
   *  - `min` equals `max`, or either bound is NaN or infinite.
   *
   * This function is called by generated Java code, so it needs to be public.
   */
  def isNull(value: Double, min: Double, max: Double, numBucket: Long): Boolean = {
    val invalidBucketCount = numBucket <= 0 || numBucket == Long.MaxValue
    val invalidBounds = min == max || !jl.Double.isFinite(min) || !jl.Double.isFinite(max)
    invalidBucketCount || jl.Double.isNaN(value) || invalidBounds
  }

  /**
   * Bucket number for arguments already accepted by [[isNull]]. Values below the histogram range
   * map to 0 and values at or beyond its end map to `numBucket + 1`. When `min > max` the
   * histogram runs downward from `min`.
   *
   * This function is called by generated Java code, so it needs to be public.
   */
  def computeBucketNumberNotNull(
      value: Double, min: Double, max: Double, numBucket: Long): jl.Long = {
    val lower = Math.min(min, max)
    val upper = Math.max(min, max)
    val width = upper - lower

    val bucket: Long =
      if (min < max) {
        if (value < lower) 0L
        else if (value >= upper) numBucket + 1L
        else (numBucket.toDouble * (value - lower) / width).toLong + 1L
      } else {
        // `min > max`: buckets are counted from `min` (the upper bound) downward.
        if (value > upper) 0L
        else if (value <= lower) numBucket + 1L
        else (numBucket.toDouble * (upper - value) / width).toLong + 1L
      }
    bucket
  }
}
/**
 * Returns the bucket number into which the value of this expression would fall
 * after being evaluated. Note that input arguments must follow conditions listed below;
 * otherwise, the method will return null.
 *  - `numBucket` must be greater than zero and be less than Long.MaxValue
 *  - `value`, `min`, and `max` cannot be NaN
 *  - `min` bound cannot equal `max`
 *  - `min` and `max` must be finite
 *
 * Note: If `minValue` > `maxValue`, a return value is as follows;
 *  if `value` > `minValue`, it returns 0.
 *  if `value` <= `maxValue`, it returns `numBucket` + 1.
 *  otherwise, it returns (`numBucket` * (`minValue` - `value`) / (`minValue` - `maxValue`)) + 1
 *
 * `value`, `minValue` and `maxValue` must all be DOUBLE, all YEAR-MONTH intervals, or all
 * DAY-TIME intervals; interval values are bucketed by their numeric representation.
 *
 * @param value is the expression to compute a bucket number in the histogram
 * @param minValue is the minimum value of the histogram
 * @param maxValue is the maximum value of the histogram
 * @param numBucket is the number of buckets
 */
@ExpressionDescription(
  usage = """
    _FUNC_(value, min_value, max_value, num_bucket) - Returns the bucket number to which
      `value` would be assigned in an equiwidth histogram with `num_bucket` buckets,
      in the range `min_value` to `max_value`.
  """,
  examples = """
    Examples:
      > SELECT _FUNC_(5.3, 0.2, 10.6, 5);
       3
      > SELECT _FUNC_(-2.1, 1.3, 3.4, 3);
       0
      > SELECT _FUNC_(8.1, 0.0, 5.7, 4);
       5
      > SELECT _FUNC_(-0.9, 5.2, 0.5, 2);
       3
      > SELECT _FUNC_(INTERVAL '0' YEAR, INTERVAL '0' YEAR, INTERVAL '10' YEAR, 10);
       1
      > SELECT _FUNC_(INTERVAL '1' YEAR, INTERVAL '0' YEAR, INTERVAL '10' YEAR, 10);
       2
      > SELECT _FUNC_(INTERVAL '0' DAY, INTERVAL '0' DAY, INTERVAL '10' DAY, 10);
       1
      > SELECT _FUNC_(INTERVAL '1' DAY, INTERVAL '0' DAY, INTERVAL '10' DAY, 10);
       2
  """,
  since = "3.1.0",
  group = "math_funcs")
case class WidthBucket(
    value: Expression,
    minValue: Expression,
    maxValue: Expression,
    numBucket: Expression)
  extends QuaternaryExpression with ImplicitCastInputTypes with NullIntolerant {

  override def inputTypes: Seq[AbstractDataType] = Seq(
    TypeCollection(DoubleType, YearMonthIntervalType, DayTimeIntervalType),
    TypeCollection(DoubleType, YearMonthIntervalType, DayTimeIntervalType),
    TypeCollection(DoubleType, YearMonthIntervalType, DayTimeIntervalType),
    LongType)

  override def checkInputDataTypes(): TypeCheckResult = {
    super.checkInputDataTypes() match {
      case TypeCheckSuccess =>
        (value.dataType, minValue.dataType, maxValue.dataType) match {
          case (_: YearMonthIntervalType, _: YearMonthIntervalType, _: YearMonthIntervalType) =>
            TypeCheckSuccess
          case (_: DayTimeIntervalType, _: DayTimeIntervalType, _: DayTimeIntervalType) =>
            TypeCheckSuccess
          case _ =>
            val types = Seq(value.dataType, minValue.dataType, maxValue.dataType)
            TypeUtils.checkForSameTypeInputExpr(types, s"function $prettyName")
        }
      case f => f
    }
  }

  override def dataType: DataType = LongType
  // Invalid bounds or bucket counts yield null even for non-null inputs.
  override def nullable: Boolean = true
  override def prettyName: String = "width_bucket"

  override protected def nullSafeEval(input: Any, min: Any, max: Any, numBucket: Any): Any = {
    WidthBucket.computeBucketNumber(
      input.asInstanceOf[Number].doubleValue(),
      min.asInstanceOf[Number].doubleValue(),
      max.asInstanceOf[Number].doubleValue(),
      numBucket.asInstanceOf[Long])
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(ctx, ev, (input, min, max, numBucket) => {
      s"""${ev.isNull} = org.apache.spark.sql.catalyst.expressions.WidthBucket
         |  .isNull($input, $min, $max, $numBucket);
         |if (!${ev.isNull}) {
         |  ${ev.value} = org.apache.spark.sql.catalyst.expressions.WidthBucket
         |    .computeBucketNumberNotNull($input, $min, $max, $numBucket);
         |}""".stripMargin
    })
  }

  override def first: Expression = value
  override def second: Expression = minValue
  override def third: Expression = maxValue
  override def fourth: Expression = numBucket

  override protected def withNewChildrenInternal(
      first: Expression, second: Expression, third: Expression, fourth: Expression): WidthBucket =
    copy(value = first, minValue = second, maxValue = third, numBucket = fourth)
}
相关信息
相关文章
spark ApplyFunctionExpression 源码
spark BloomFilterMightContain 源码
spark CallMethodViaReflection 源码
0
赞
- 所属分类: 前端技术
- 本文标签:
热门推荐
-
2、 - 优质文章
-
3、 gate.io
-
7、 golang
-
9、 openharmony
-
10、 Vue中input框自动聚焦