spark UDTRegistration 源码
spark UDTRegistration 代码
文件路径:/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UDTRegistration.scala
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.types
import scala.collection.mutable
import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.util.Utils
/**
* This object keeps the mappings between user classes and their User Defined Types (UDTs).
* Previously we use the annotation `SQLUserDefinedType` to register UDTs for user classes.
* However, by doing this, we add SparkSQL dependency on user classes. This object provides
* alternative approach to register UDTs for user classes.
*/
@DeveloperApi
@Since("3.2.0")
object UDTRegistration extends Serializable with Logging {
/** The mapping between the Class between UserDefinedType and user classes. */
private lazy val udtMap: mutable.Map[String, String] = mutable.Map(
("org.apache.spark.ml.linalg.Vector", "org.apache.spark.ml.linalg.VectorUDT"),
("org.apache.spark.ml.linalg.DenseVector", "org.apache.spark.ml.linalg.VectorUDT"),
("org.apache.spark.ml.linalg.SparseVector", "org.apache.spark.ml.linalg.VectorUDT"),
("org.apache.spark.ml.linalg.Matrix", "org.apache.spark.ml.linalg.MatrixUDT"),
("org.apache.spark.ml.linalg.DenseMatrix", "org.apache.spark.ml.linalg.MatrixUDT"),
("org.apache.spark.ml.linalg.SparseMatrix", "org.apache.spark.ml.linalg.MatrixUDT"))
/**
* Queries if a given user class is already registered or not.
* @param userClassName the name of user class
* @return boolean value indicates if the given user class is registered or not
*/
def exists(userClassName: String): Boolean = udtMap.contains(userClassName)
/**
* Registers an UserDefinedType to an user class. If the user class is already registered
* with another UserDefinedType, warning log message will be shown.
* @param userClass the name of user class
* @param udtClass the name of UserDefinedType class for the given userClass
*/
def register(userClass: String, udtClass: String): Unit = {
if (udtMap.contains(userClass)) {
logWarning(s"Cannot register UDT for ${userClass}, which is already registered.")
} else {
// When register UDT with class name, we can't check if the UDT class is an UserDefinedType,
// or not. The check is deferred.
udtMap += ((userClass, udtClass))
}
}
/**
* Returns the Class of UserDefinedType for the name of a given user class.
* @param userClass class name of user class
* @return Option value of the Class object of UserDefinedType
*/
def getUDTFor(userClass: String): Option[Class[_]] = {
udtMap.get(userClass).map { udtClassName =>
if (Utils.classIsLoadable(udtClassName)) {
val udtClass = Utils.classForName(udtClassName)
if (classOf[UserDefinedType[_]].isAssignableFrom(udtClass)) {
udtClass
} else {
throw QueryExecutionErrors.notUserDefinedTypeError(udtClass.getName, userClass)
}
} else {
throw QueryExecutionErrors.cannotLoadUserDefinedTypeError(udtClassName, userClass)
}
}
}
}
相关信息
相关文章
0
赞
- 所属分类: 前端技术
- 本文标签:
热门推荐
-
2、 - 优质文章
-
3、 gate.io
-
7、 golang
-
9、 openharmony
-
10、 Vue中input框自动聚焦