Skip to content

Commit bb9cc4b

Browse files
authored
fix: share unified memory pools across native execution contexts within a task (#3924)
1 parent 38da631 commit bb9cc4b

5 files changed

Lines changed: 61 additions & 21 deletions

File tree

docs/source/user-guide/latest/tuning.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,12 @@ The valid pool types are:
6161
- `fair_unified` (default when `spark.memory.offHeap.enabled=true` is set)
6262
- `greedy_unified`
6363

64-
The `fair_unified` pool types prevents operators from using more than an even fraction of the available memory
64+
Both pool types are shared across all native execution contexts within the same Spark task. When
65+
Comet executes a shuffle, it runs two native execution contexts concurrently (e.g. one for
66+
pre-shuffle operators and one for the shuffle writer). The shared pool ensures that the combined
67+
memory usage stays within the per-task limit.
68+
69+
The `fair_unified` pool prevents operators from using more than an even fraction of the available memory
6570
(i.e. `pool_size / num_reservations`). This pool works best when you know beforehand
6671
the query has multiple operators that will likely all need to spill. Sometimes it will cause spills even
6772
when there is sufficient memory in order to leave enough memory for other operators.

native/Cargo.lock

Lines changed: 1 addition & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

native/core/src/execution/jni_api.rs

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ use crate::{
2626
},
2727
jvm_bridge::JVMClasses,
2828
};
29+
use std::collections::HashSet;
30+
2931
use arrow::array::{Array, RecordBatch, UInt32Array};
3032
use arrow::compute::{take, TakeOptions};
3133
use arrow::datatypes::DataType as ArrowDataType;
@@ -141,15 +143,33 @@ fn unregister_and_total(thread_id: u64, context_id: i64) -> usize {
141143
map.remove(&thread_id);
142144
return 0;
143145
}
144-
return pools.values().map(|p| p.reserved()).sum::<usize>();
146+
let mut seen = HashSet::new();
147+
return pools
148+
.values()
149+
.filter_map(|p| {
150+
let ptr = Arc::as_ptr(p) as *const ();
151+
seen.insert(ptr).then(|| p.reserved())
152+
})
153+
.sum::<usize>();
145154
}
146155
0
147156
}
148157

149158
fn total_reserved_for_thread(thread_id: u64) -> usize {
150159
let map = get_thread_memory_pools().lock();
151160
map.get(&thread_id)
152-
.map(|pools| pools.values().map(|p| p.reserved()).sum::<usize>())
161+
.map(|pools| {
162+
// Deduplicate pools that share the same underlying allocation
163+
// (e.g. task-shared pools registered by multiple execution contexts)
164+
let mut seen = HashSet::new();
165+
pools
166+
.values()
167+
.filter_map(|p| {
168+
let ptr = Arc::as_ptr(p) as *const ();
169+
seen.insert(ptr).then(|| p.reserved())
170+
})
171+
.sum::<usize>()
172+
})
153173
.unwrap_or(0)
154174
}
155175

native/core/src/execution/memory_pools/config.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,10 @@ impl MemoryPoolType {
3434
pub(crate) fn is_task_shared(&self) -> bool {
3535
matches!(
3636
self,
37-
MemoryPoolType::GreedyTaskShared | MemoryPoolType::FairSpillTaskShared
37+
MemoryPoolType::GreedyTaskShared
38+
| MemoryPoolType::FairSpillTaskShared
39+
| MemoryPoolType::FairUnified
40+
| MemoryPoolType::GreedyUnified
3841
)
3942
}
4043
}

native/core/src/execution/memory_pools/mod.rs

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -42,22 +42,36 @@ pub(crate) fn create_memory_pool(
4242
const NUM_TRACKED_CONSUMERS: usize = 10;
4343
match memory_pool_config.pool_type {
4444
MemoryPoolType::GreedyUnified => {
45-
// Set Comet memory pool for native
46-
let memory_pool =
47-
CometUnifiedMemoryPool::new(comet_task_memory_manager, task_attempt_id);
48-
Arc::new(TrackConsumersPool::new(
49-
memory_pool,
50-
NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(),
51-
))
45+
let mut memory_pool_map = TASK_SHARED_MEMORY_POOLS.lock().unwrap();
46+
let per_task_memory_pool =
47+
memory_pool_map.entry(task_attempt_id).or_insert_with(|| {
48+
let pool: Arc<dyn MemoryPool> = Arc::new(TrackConsumersPool::new(
49+
CometUnifiedMemoryPool::new(
50+
Arc::clone(&comet_task_memory_manager),
51+
task_attempt_id,
52+
),
53+
NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(),
54+
));
55+
PerTaskMemoryPool::new(pool)
56+
});
57+
per_task_memory_pool.num_plans += 1;
58+
Arc::clone(&per_task_memory_pool.memory_pool)
5259
}
5360
MemoryPoolType::FairUnified => {
54-
// Set Comet fair memory pool for native
55-
let memory_pool =
56-
CometFairMemoryPool::new(comet_task_memory_manager, memory_pool_config.pool_size);
57-
Arc::new(TrackConsumersPool::new(
58-
memory_pool,
59-
NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(),
60-
))
61+
let mut memory_pool_map = TASK_SHARED_MEMORY_POOLS.lock().unwrap();
62+
let per_task_memory_pool =
63+
memory_pool_map.entry(task_attempt_id).or_insert_with(|| {
64+
let pool: Arc<dyn MemoryPool> = Arc::new(TrackConsumersPool::new(
65+
CometFairMemoryPool::new(
66+
Arc::clone(&comet_task_memory_manager),
67+
memory_pool_config.pool_size,
68+
),
69+
NonZeroUsize::new(NUM_TRACKED_CONSUMERS).unwrap(),
70+
));
71+
PerTaskMemoryPool::new(pool)
72+
});
73+
per_task_memory_pool.num_plans += 1;
74+
Arc::clone(&per_task_memory_pool.memory_pool)
6175
}
6276
MemoryPoolType::Greedy => Arc::new(TrackConsumersPool::new(
6377
GreedyMemoryPool::new(memory_pool_config.pool_size),

0 commit comments

Comments
 (0)