Skip to content

Commit ba9b842

Browse files
authored
Hoist some stuff out of NativeBatchDecoderIterator into CometBlockStoreShuffleReader that can be reused. (#3627)
1 parent 48e32a8 commit ba9b842

2 files changed

Lines changed: 20 additions & 18 deletions

File tree

spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometBlockStoreShuffleReader.scala

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ import org.apache.spark.storage.BlockManagerId
3535
import org.apache.spark.storage.ShuffleBlockFetcherIterator
3636
import org.apache.spark.util.CompletionIterator
3737

38+
import org.apache.comet.{CometConf, Native}
39+
import org.apache.comet.vector.NativeUtil
40+
3841
/**
3942
* Shuffle reader that reads data from the block manager. It reads Arrow-serialized data (IPC
4043
* format) and returns an iterator of ColumnarBatch.
@@ -86,24 +89,32 @@ class CometBlockStoreShuffleReader[K, C](
8689
/** Read the combined key-values for this reduce task */
8790
override def read(): Iterator[Product2[K, C]] = {
8891
var currentReadIterator: NativeBatchDecoderIterator = null
92+
val nativeLib = new Native()
93+
val nativeUtil = new NativeUtil()
94+
val tracingEnabled = CometConf.COMET_TRACING_ENABLED.get()
8995

90-
// Closes last read iterator after the task is finished.
96+
// Closes last read iterator and shared resources after the task is finished.
9197
// We need to close read iterator during iterating input streams,
9298
// instead of one callback per read iterator. Otherwise if there are too many
9399
// read iterators, it may blow up the call stack and cause OOM.
94100
context.addTaskCompletionListener[Unit] { _ =>
95101
if (currentReadIterator != null) {
96102
currentReadIterator.close()
97103
}
104+
nativeUtil.close()
98105
}
99106

100107
val recordIter: Iterator[(Int, ColumnarBatch)] = fetchIterator
101108
.flatMap(blockIdAndStream => {
102109
if (currentReadIterator != null) {
103110
currentReadIterator.close()
104111
}
105-
currentReadIterator =
106-
NativeBatchDecoderIterator(blockIdAndStream._2, context, dep.decodeTime)
112+
currentReadIterator = NativeBatchDecoderIterator(
113+
blockIdAndStream._2,
114+
dep.decodeTime,
115+
nativeLib,
116+
nativeUtil,
117+
tracingEnabled)
107118
currentReadIterator
108119
})
109120
.map(b => (0, b))

spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/NativeBatchDecoderIterator.scala

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,10 @@ import java.io.{EOFException, InputStream}
2323
import java.nio.{ByteBuffer, ByteOrder}
2424
import java.nio.channels.{Channels, ReadableByteChannel}
2525

26-
import org.apache.spark.TaskContext
2726
import org.apache.spark.sql.execution.metric.SQLMetric
2827
import org.apache.spark.sql.vectorized.ColumnarBatch
2928

30-
import org.apache.comet.{CometConf, Native}
29+
import org.apache.comet.Native
3130
import org.apache.comet.vector.NativeUtil
3231

3332
/**
@@ -37,26 +36,19 @@ import org.apache.comet.vector.NativeUtil
3736
*/
3837
case class NativeBatchDecoderIterator(
3938
in: InputStream,
40-
taskContext: TaskContext,
41-
decodeTime: SQLMetric)
39+
decodeTime: SQLMetric,
40+
nativeLib: Native,
41+
nativeUtil: NativeUtil,
42+
tracingEnabled: Boolean)
4243
extends Iterator[ColumnarBatch] {
4344

4445
private var isClosed = false
4546
private val longBuf = ByteBuffer.allocate(8).order(ByteOrder.LITTLE_ENDIAN)
46-
private val native = new Native()
47-
private val nativeUtil = new NativeUtil()
48-
private val tracingEnabled = CometConf.COMET_TRACING_ENABLED.get()
4947
private var currentBatch: ColumnarBatch = null
5048
private var batch = fetchNext()
5149

5250
import NativeBatchDecoderIterator._
5351

54-
if (taskContext != null) {
55-
taskContext.addTaskCompletionListener[Unit](_ => {
56-
close()
57-
})
58-
}
59-
6052
private val channel: ReadableByteChannel = if (in != null) {
6153
Channels.newChannel(in)
6254
} else {
@@ -163,7 +155,7 @@ case class NativeBatchDecoderIterator(
163155
val batch = nativeUtil.getNextBatch(
164156
fieldCount,
165157
(arrayAddrs, schemaAddrs) => {
166-
native.decodeShuffleBlock(
158+
nativeLib.decodeShuffleBlock(
167159
dataBuf,
168160
bytesToRead.toInt,
169161
arrayAddrs,
@@ -183,7 +175,6 @@ case class NativeBatchDecoderIterator(
183175
currentBatch = null
184176
}
185177
in.close()
186-
nativeUtil.close()
187178
resetDataBuf()
188179
isClosed = true
189180
}

0 commit comments

Comments (0)