forked from apache/datafusion-comet
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathCometExecIterator.scala
More file actions
360 lines (313 loc) · 12.9 KB
/
CometExecIterator.scala
File metadata and controls
360 lines (313 loc) · 12.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.comet
import java.lang.management.ManagementFactory
import org.apache.hadoop.conf.Configuration
import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.network.util.ByteUnit
import org.apache.spark.sql.comet.CometMetricNode
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.vectorized._
import org.apache.spark.util.SerializableConfiguration
import org.apache.comet.CometConf._
import org.apache.comet.Tracing.withTrace
import org.apache.comet.exceptions.CometQueryExecutionException
import org.apache.comet.parquet.CometFileKeyUnwrapper
import org.apache.comet.serde.Config.ConfigMap
import org.apache.comet.vector.NativeUtil
/**
 * An iterator class used to execute Comet native query. It takes input iterators which come
 * from Comet Scan (or, for shuffle scan inputs, [[CometShuffleBlockIterator]]s) and are expected
 * to produce batches of Arrow Arrays. During consuming this iterator, it will consume the input
 * iterators and pass Arrow Arrays to Comet native engine by addresses. Even after the end of the
 * input iterators, this iterator still possibly continues executing native query as there might
 * be blocking operators such as Sort, Aggregate. The API `hasNext` can be used to check if it is
 * the end of this iterator (i.e. the native query is done).
 *
 * Native resources are released by `close()`, which is called when the native query is
 * exhausted and also from a task-completion listener registered at construction time.
 *
 * @param id
 *   Identifier passed to the native plan and to the task memory manager.
 * @param inputs
 *   The input iterators producing sequence of batches of Arrow Arrays. An entry whose index is a
 *   key of `shuffleBlockIterators` is replaced by that shuffle block iterator.
 * @param numOutputCols
 *   Number of output columns, passed to `NativeUtil.getNextBatch` for every output batch.
 * @param protobufQueryPlan
 *   The serialized (protobuf) bytes of the execution plan.
 * @param nativeMetrics
 *   Metrics node handed to the native plan.
 * @param numParts
 *   The number of partitions.
 * @param partitionIndex
 *   The index of the partition; must satisfy `0 <= partitionIndex < numParts`.
 * @param broadcastedHadoopConfForEncryption
 *   Hadoop configuration used to build decryption key retrievers. Must be defined whenever
 *   `encryptedFilePaths` is non-empty.
 * @param encryptedFilePaths
 *   Paths to encrypted Parquet files that need key unwrapping.
 * @param shuffleBlockIterators
 *   Shuffle block iterators keyed by input index; they are closed by `close()`.
 */
class CometExecIterator(
    val id: Long,
    inputs: Seq[Iterator[ColumnarBatch]],
    numOutputCols: Int,
    protobufQueryPlan: Array[Byte],
    nativeMetrics: CometMetricNode,
    numParts: Int,
    partitionIndex: Int,
    broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]] = None,
    encryptedFilePaths: Seq[String] = Seq.empty,
    shuffleBlockIterators: Map[Int, CometShuffleBlockIterator] = Map.empty)
    extends Iterator[ColumnarBatch]
    with Logging {

  private val tracingEnabled = CometConf.COMET_TRACING_ENABLED.get()
  private val memoryMXBean = ManagementFactory.getMemoryMXBean
  private val nativeLib = new Native()
  private val nativeUtil = new NativeUtil()
  private val taskAttemptId = TaskContext.get().taskAttemptId
  private val taskCPUs = TaskContext.get().cpus()
  private val cometTaskMemoryManager = new CometTaskMemoryManager(id, taskAttemptId)

  // Build a mixed array of iterators: CometShuffleBlockIterator for shuffle
  // scan indices, CometBatchIterator for regular scan indices.
  private val inputIterators: Array[Object] = inputs.zipWithIndex.map {
    case (_, idx) if shuffleBlockIterators.contains(idx) =>
      shuffleBlockIterators(idx).asInstanceOf[Object]
    case (iterator, _) =>
      new CometBatchIterator(iterator, nativeUtil).asInstanceOf[Object]
  }.toArray

  // Handle to the native plan, created once per iterator and released in close().
  private val plan = {
    val conf = SparkEnv.get.conf
    val localDiskDirs = SparkEnv.get.blockManager.getLocalDiskDirs

    // serialize Comet related Spark configs in protobuf format
    val protobufSparkConfigs = CometExecIterator.serializeCometSQLConfs()

    // Create keyUnwrapper if encryption is enabled
    val keyUnwrapper = if (encryptedFilePaths.nonEmpty) {
      val unwrapper = new CometFileKeyUnwrapper()
      val hadoopConf: Configuration = broadcastedHadoopConfForEncryption.get.value.value

      encryptedFilePaths.foreach(filePath =>
        unwrapper.storeDecryptionKeyRetriever(filePath, hadoopConf))

      unwrapper
    } else {
      null
    }

    val memoryConfig = CometExecIterator.getMemoryConfig(conf)

    nativeLib.createPlan(
      id,
      inputIterators,
      protobufQueryPlan,
      protobufSparkConfigs,
      numParts,
      nativeMetrics,
      metricsUpdateInterval = COMET_METRICS_UPDATE_INTERVAL.get(),
      cometTaskMemoryManager,
      localDiskDirs,
      batchSize = COMET_BATCH_SIZE.get(),
      memoryConfig.offHeapMode,
      memoryConfig.memoryPoolType,
      memoryConfig.memoryLimit,
      memoryConfig.memoryLimitPerTask,
      taskAttemptId,
      taskCPUs,
      keyUnwrapper)
  }

  // Batch fetched by hasNext() but not yet returned by next().
  private var nextBatch: Option[ColumnarBatch] = None
  // Batch returned by the latest next(); closed before the native side produces the next one.
  private var prevBatch: ColumnarBatch = null
  // Batch most recently returned by next().
  private var currentBatch: ColumnarBatch = null
  private var closed: Boolean = false

  // Register a task completion listener to ensure native resources are released
  // when the task is done.
  TaskContext.get().addTaskCompletionListener[Unit] { _ =>
    this.close()
  }

  /**
   * Asks the native plan for its next output batch.
   *
   * @return
   *   the next batch, or `None` when the native query has no more output
   * @throws SparkException
   *   converted from native errors (`CometQueryExecutionException`, or a `CometNativeException`
   *   whose message is a Parquet error); other exceptions propagate unchanged
   */
  private def getNextBatch: Option[ColumnarBatch] = {
    assert(partitionIndex >= 0 && partitionIndex < numParts)

    val ctx = TaskContext.get()

    try {
      val result = withTrace(
        s"getNextBatch[JVM] stage=${ctx.stageId()}",
        tracingEnabled, {
          nativeUtil.getNextBatch(
            numOutputCols,
            (arrayAddrs, schemaAddrs) => {
              nativeLib.executePlan(ctx.stageId(), partitionIndex, plan, arrayAddrs, schemaAddrs)
            })
        })

      if (tracingEnabled) {
        traceMemoryUsage()
      }

      result
    } catch {
      // Handle CometQueryExecutionException with JSON payload first
      case e: CometQueryExecutionException =>
        logError(s"Native execution for task $taskAttemptId failed", e)
        throw SparkErrorConverter.convertToSparkException(e)

      case e: CometNativeException =>
        // it is generally considered bad practice to log and then rethrow an
        // exception, but it really helps debugging to be able to see which task
        // threw the exception, so we log the exception with taskAttemptId here
        logError(s"Native execution for task $taskAttemptId failed", e)

        val parquetError: scala.util.matching.Regex =
          """^Parquet error: (?:.*)$""".r
        // getMessage may be null; wrap it so a null message can never reach the regex
        // extractor and is simply treated as "not a Parquet error".
        Option(e.getMessage) match {
          case Some(parquetError()) =>
            // See org.apache.spark.sql.errors.QueryExecutionErrors.failedToReadDataError
            // See org.apache.parquet.hadoop.ParquetFileReader for error message.
            // _LEGACY_ERROR_TEMP_2254 has no message placeholders; Spark 4 strict-checks
            // parameters and raises INTERNAL_ERROR if any are passed.
            throw new SparkException(
              errorClass = "_LEGACY_ERROR_TEMP_2254",
              messageParameters = Map.empty,
              cause = new SparkException("File is not a Parquet file.", e))
          case _ =>
            throw e
        }
      // Any other exception propagates unchanged.
    }
  }

  /**
   * Returns true when another batch is available. Closes the previously returned batch before
   * asking the native side for more data, and closes this iterator once the native query is
   * exhausted.
   */
  override def hasNext: Boolean = {
    if (closed) return false

    if (nextBatch.isDefined) {
      return true
    }

    // Close previous batch if any.
    // This is to guarantee safety at the native side before we overwrite the buffer memory
    // shared across batches in the native side.
    if (prevBatch != null) {
      prevBatch.close()
      prevBatch = null
    }

    nextBatch = getNextBatch

    logTrace(s"Task $taskAttemptId memory pool usage is ${cometTaskMemoryManager.getUsed} bytes")

    if (nextBatch.isEmpty) {
      close()
      false
    } else {
      true
    }
  }

  /**
   * Returns the next batch. The batch returned by the previous call is closed first, so callers
   * must not use a batch after calling `next()` again.
   *
   * @throws NoSuchElementException
   *   if there are no more batches
   */
  override def next(): ColumnarBatch = {
    if (currentBatch != null) {
      // Eagerly release Arrow Arrays in the previous batch
      currentBatch.close()
      currentBatch = null
    }

    if (nextBatch.isEmpty && !hasNext) {
      throw new NoSuchElementException("No more element")
    }

    currentBatch = nextBatch.get
    prevBatch = currentBatch
    nextBatch = None
    currentBatch
  }

  /**
   * Releases all resources held by this iterator. Only the first call has any effect.
   *
   * `closed` is set before any cleanup runs, so a failure part-way through can never cause a
   * later call (e.g. from the task-completion listener) to release the native plan a second
   * time. The native plan is released in a `finally` block, so a failure while closing JVM-side
   * resources does not leak it.
   */
  def close(): Unit = synchronized {
    if (!closed) {
      closed = true
      try {
        // A batch fetched by hasNext() but never returned by next() was never handed to a
        // caller, so nobody else will close it.
        nextBatch.foreach(_.close())
        nextBatch = None

        if (currentBatch != null) {
          currentBatch.close()
          currentBatch = null
        }
        // prevBatch only ever aliases currentBatch (see next()/hasNext()), which is now closed.
        prevBatch = null

        try {
          nativeUtil.close()
        } finally {
          shuffleBlockIterators.values.foreach(_.close())
        }
      } finally {
        nativeLib.releasePlan(plan)

        if (tracingEnabled) {
          traceMemoryUsage()
        }

        val memInUse = cometTaskMemoryManager.getUsed
        if (memInUse != 0) {
          logWarning(s"CometExecIterator closed with non-zero memory usage : $memInUse")
        }
      }
    }
  }

  /** Reports the current JVM heap usage to the native memory-usage log. */
  private def traceMemoryUsage(): Unit = {
    nativeLib.logMemoryUsage("jvm_heap_used", memoryMXBean.getHeapMemoryUsage.getUsed)
  }
}
object CometExecIterator extends Logging {

  // Master URL `local[N]` or `local[*]`; `*` means all available cores. See
  // https://spark.apache.org/docs/latest/submitting-applications.html#master-urls
  // Compiled once here instead of on every numDriverOrExecutorCores call.
  private val LocalNRegex = """local\[([0-9]+|\*)\]""".r

  // Master URL `local[N, maxFailures]` or `local[*, maxFailures]`.
  private val LocalNFailuresRegex = """local\[([0-9]+|\*)\s*,\s*([0-9]+)\]""".r

  /** All SQL confs of the active session whose keys start with `CometConf.COMET_PREFIX`. */
  private def cometSqlConfs: Map[String, String] =
    SQLConf.get.getAllConfs.filter(_._1.startsWith(CometConf.COMET_PREFIX))

  /**
   * Serializes the Comet-prefixed SQL confs into a protobuf `ConfigMap` for the native side.
   *
   * Keys under `<COMET_PREFIX>.datafusion.` are included only when
   * `COMET_RESPECT_DATAFUSION_CONFIGS` is enabled. The resolved driver/executor core count is
   * added under `spark.executor.cores`.
   *
   * @return
   *   the serialized `ConfigMap` bytes
   */
  def serializeCometSQLConfs(): Array[Byte] = {
    val builder = ConfigMap.newBuilder()
    val datafusionPrefix = s"${CometConf.COMET_PREFIX}.datafusion."
    // Loop-invariant: read the flag once instead of once per DataFusion entry.
    val respectDatafusionConfigs = CometConf.COMET_RESPECT_DATAFUSION_CONFIGS.get(SQLConf.get)

    cometSqlConfs.foreach { case (k, v) =>
      if (!k.startsWith(datafusionPrefix) || respectDatafusionConfigs) {
        builder.putEntries(k, v)
      }
    }

    // Inject the resolved executor cores so the native side can use it
    // for tokio runtime thread count
    val executorCores = numDriverOrExecutorCores(SparkEnv.get.conf)
    builder.putEntries("spark.executor.cores", executorCores.toString)

    builder.build().toByteArray
  }

  /**
   * Computes the memory settings passed to the native plan.
   *
   * In off-heap mode the pool is `COMET_OFFHEAP_MEMORY_POOL_FRACTION` of
   * `spark.memory.offHeap.size`. Otherwise it is the Comet memory overhead. In both modes the
   * per-task limit is the pool scaled by `spark.task.cpus / cores`.
   */
  def getMemoryConfig(conf: SparkConf): MemoryConfig = {
    val numCores = numDriverOrExecutorCores(conf)
    val coresPerTask = conf.get("spark.task.cpus", "1").toInt

    // there are different paths for on-heap vs off-heap mode
    val offHeapMode = CometSparkSessionExtensions.isOffHeapEnabled(conf)
    if (offHeapMode) {
      // in off-heap mode, Comet uses unified memory management to share off-heap memory with Spark
      val offHeapSize = ByteUnit.MiB.toBytes(conf.getSizeAsMb("spark.memory.offHeap.size"))
      val memoryFraction = CometConf.COMET_OFFHEAP_MEMORY_POOL_FRACTION.get()
      val memoryLimit = (offHeapSize * memoryFraction).toLong
      val memoryLimitPerTask = perTaskLimit(memoryLimit, coresPerTask, numCores)
      val memoryPoolType = COMET_OFFHEAP_MEMORY_POOL_TYPE.get()
      logInfo(
        s"memoryPoolType=$memoryPoolType, " +
          s"offHeapSize=${toMB(offHeapSize)}, " +
          s"memoryFraction=$memoryFraction, " +
          s"memoryLimit=${toMB(memoryLimit)}, " +
          s"memoryLimitPerTask=${toMB(memoryLimitPerTask)}")
      MemoryConfig(
        offHeapMode = offHeapMode,
        memoryPoolType = memoryPoolType,
        memoryLimit = memoryLimit,
        memoryLimitPerTask = memoryLimitPerTask)
    } else {
      // in on-heap mode the pool size comes from the Comet memory overhead setting
      val memoryLimit = CometSparkSessionExtensions.getCometMemoryOverhead(conf)
      val memoryLimitPerTask = perTaskLimit(memoryLimit, coresPerTask, numCores)
      val memoryPoolType = COMET_ONHEAP_MEMORY_POOL_TYPE.get()
      logInfo(
        s"memoryPoolType=$memoryPoolType, " +
          s"memoryLimit=${toMB(memoryLimit)}, " +
          s"memoryLimitPerTask=${toMB(memoryLimitPerTask)}")
      MemoryConfig(
        offHeapMode = offHeapMode,
        memoryPoolType = memoryPoolType,
        memoryLimit = memoryLimit,
        memoryLimitPerTask = memoryLimitPerTask)
    }
  }

  /**
   * Share of `memoryLimit` available to one task. Example: a 16 GB limit on 16 cores with 4
   * cores per task gives 16 GB * 4 / 16 = 4 GB.
   */
  private def perTaskLimit(memoryLimit: Long, coresPerTask: Int, numCores: Int): Long =
    (memoryLimit.toDouble * coresPerTask / numCores).toLong

  /**
   * Number of cores available to this driver/executor. In local mode it is derived from
   * `spark.master` (`local` = 1, `local[N]`/`local[N, F]` = N, `*` = all available
   * processors). Otherwise it is `spark.executor.cores` (default 1).
   */
  private def numDriverOrExecutorCores(conf: SparkConf): Int = {
    def convertToInt(threads: String): Int =
      if (threads == "*") Runtime.getRuntime.availableProcessors() else threads.toInt

    conf.get("spark.master") match {
      case "local" => 1
      case LocalNRegex(threads) => convertToInt(threads)
      case LocalNFailuresRegex(threads, _) => convertToInt(threads)
      case _ => conf.get("spark.executor.cores", "1").toInt
    }
  }

  /** Formats a byte count as whole mebibytes for logging. */
  private def toMB(n: Long): String = {
    s"${(n.toDouble / 1024.0 / 1024.0).toLong} MB"
  }
}
/**
 * Memory settings computed by `CometExecIterator.getMemoryConfig` and passed to the native plan
 * by `CometExecIterator`.
 *
 * @param offHeapMode
 *   whether Spark off-heap memory is enabled (`CometSparkSessionExtensions.isOffHeapEnabled`)
 * @param memoryPoolType
 *   the configured Comet memory pool type for the active mode (off-heap or on-heap config)
 * @param memoryLimit
 *   total memory limit; bytes in off-heap mode (logged in MB in both modes)
 * @param memoryLimitPerTask
 *   share of `memoryLimit` for one task: `memoryLimit * spark.task.cpus / cores`
 */
case class MemoryConfig(
offHeapMode: Boolean,
memoryPoolType: String,
memoryLimit: Long,
memoryLimitPerTask: Long)