// Apache Spark source file:
// sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.aggregate
import java.{util => ju}
import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.internal.config
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, TypedImperativeAggregate}
import org.apache.spark.sql.execution.UnsafeKVExternalSorter
import org.apache.spark.sql.types.StructType
/**
* An aggregation map that supports using safe `SpecificInternalRow`s aggregation buffers, so that
* we can support storing arbitrary Java objects as aggregate function states in the aggregation
* buffers. This class is only used together with [[ObjectHashAggregateExec]].
*/
class ObjectAggregationMap() {

  // Insertion-ordered map from grouping key to its (safe-row) aggregation buffer.
  // LinkedHashMap preserves insertion order so the destructive iterator and the
  // spill path emit entries deterministically.
  private[this] val buffers = new ju.LinkedHashMap[UnsafeRow, InternalRow]

  /** Returns the aggregation buffer for `groupingKey`, or `null` when absent. */
  def getAggregationBuffer(groupingKey: UnsafeRow): InternalRow = buffers.get(groupingKey)

  /** Registers `aggBuffer` as the buffer for `groupingKey`, replacing any previous one. */
  def putAggregationBuffer(groupingKey: UnsafeRow, aggBuffer: InternalRow): Unit = {
    buffers.put(groupingKey, aggBuffer)
  }

  /** Number of distinct grouping keys currently held in the map. */
  def size: Int = buffers.size()

  /**
   * Returns a destructive iterator of [[AggregationBufferEntry]]: each entry is removed
   * from the underlying map as it is produced, so consumed entries become eligible for GC.
   * Notice: it is illegal to call any method of this map after `destructiveIterator()`
   * has been called.
   */
  def destructiveIterator(): Iterator[AggregationBufferEntry] = {
    val underlying = buffers.entrySet().iterator()
    new Iterator[AggregationBufferEntry] {
      override def hasNext: Boolean = underlying.hasNext

      override def next(): AggregationBufferEntry = {
        val entry = underlying.next()
        // Drop the entry from the map immediately so memory is released as we go.
        underlying.remove()
        new AggregationBufferEntry(entry.getKey, entry.getValue)
      }
    }
  }

  /**
   * Dumps all entries into a newly created external sorter, clears the hash map, and
   * returns the external sorter.
   *
   * Object-typed (`TypedImperativeAggregate`) buffers are serialized in place first,
   * because only binary data can be written into the unsafe-row based sorter.
   */
  def dumpToExternalSorter(
      groupingAttributes: Seq[Attribute],
      aggregateFunctions: Seq[AggregateFunction]): UnsafeKVExternalSorter = {
    val aggBufferAttributes = aggregateFunctions.flatMap(_.aggBufferAttributes)
    val externalSorter = new UnsafeKVExternalSorter(
      StructType.fromAttributes(groupingAttributes),
      StructType.fromAttributes(aggBufferAttributes),
      SparkEnv.get.blockManager,
      SparkEnv.get.serializerManager,
      TaskContext.get().taskMemoryManager().pageSizeBytes,
      SparkEnv.get.conf.get(config.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD),
      null
    )
    // Projection turning a safe aggregation buffer row into an UnsafeRow for the sorter.
    val toUnsafeBuffer =
      UnsafeProjection.create(aggBufferAttributes.map(_.dataType).toArray)
    destructiveIterator().foreach { entry =>
      // Serialize object-typed buffers in place so the projection below sees binary data.
      aggregateFunctions.foreach {
        case agg: TypedImperativeAggregate[_] =>
          agg.serializeAggregateBufferInPlace(entry.aggregationBuffer)
        case _ =>
      }
      externalSorter.insertKV(
        entry.groupingKey,
        toUnsafeBuffer(entry.aggregationBuffer)
      )
    }
    // The destructive iterator already emptied the map; this is defensive.
    buffers.clear()
    externalSorter
  }

  /** Removes all entries from the map. */
  def clear(): Unit = {
    buffers.clear()
  }
}
/**
 * A single (grouping key, aggregation buffer) pair produced by
 * `ObjectAggregationMap.destructiveIterator()`. Fields are `var`s so the buffer row can
 * be serialized/replaced in place before being spilled to an external sorter.
 */
class AggregationBufferEntry(var groupingKey: UnsafeRow, var aggregationBuffer: InternalRow)