
Commit 2696951

Dandandan authored and jorgecarleitao committed
ARROW-11300: [Rust][DataFusion] Further performance improvements on hash aggregation with small groups
Based on #9234, this PR improves the situation described in https://issues.apache.org/jira/browse/ARROW-11300. Currently we call `take` on the arrays once per group, which is fine in itself, but it creates and allocates many small `Array`s when each group contains only a few rows.

This improves the results of the group-by queries on db-benchmark.

This PR:

```
q1 took 32 ms
q2 took 422 ms
q3 took 3468 ms
q4 took 44 ms
q5 took 3166 ms
q7 took 3081 ms
```

#9234 (results differ from that PR's description because partitioning is now enabled and a custom allocator is used):

```
q1 took 34 ms
q2 took 389 ms
q3 took 4590 ms
q4 took 47 ms
q5 took 5152 ms
q7 took 3941 ms
```

The PR changes the algorithm to:

* Collect the indices and offsets of all keys that appear in the batch.
* `take` from each array based on all those indices in one go, so only one larger allocation is needed per array (see the sketch below).
* `slice` each key's values out of the taken arrays, based on the offsets, and pass them to the accumulators.

Closes #9271 from Dandandan/hash_agg_few_rows

Authored-by: Heres, Daniel <[email protected]>
Signed-off-by: Jorge C. Leitao <[email protected]>
1 parent 61b0cb1 commit 2696951
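For illustration, here is a minimal standalone sketch of the take-once-then-slice pattern, reduced to a single column. `gather_per_group` and its inputs are hypothetical, not code from this PR, and the arrow calls assume the same crate version the diff uses (a capacity argument to `UInt32Builder::new`, fallible `append_slice`):

```rust
use std::sync::Arc;

use arrow::array::{Array, ArrayRef, Int32Array, UInt32Builder};
use arrow::compute;
use arrow::error::Result;

/// Hypothetical helper: gather each group's rows of one column with a
/// single `take`, then hand out zero-copy `slice`s per group.
fn gather_per_group(column: &ArrayRef, groups: &[Vec<u32>]) -> Result<Vec<ArrayRef>> {
    // 1. Concatenate every group's row indices into one UInt32Array,
    //    recording the offset at which each group's run starts.
    let mut batch_indices = UInt32Builder::new(0);
    let mut offsets = vec![0usize];
    for indices in groups {
        batch_indices.append_slice(indices)?;
        offsets.push(offsets.last().unwrap() + indices.len());
    }
    let batch_indices = batch_indices.finish();

    // 2. One `take` for the whole batch: a single larger allocation
    //    instead of one small array per group.
    let taken = compute::take(column.as_ref(), &batch_indices, None)?;

    // 3. `slice` each group's contiguous run back out; slices are views,
    //    so this allocates no new value buffers.
    Ok(offsets
        .windows(2)
        .map(|w| taken.slice(w[0], w[1] - w[0]))
        .collect())
}

fn main() -> Result<()> {
    let column: ArrayRef = Arc::new(Int32Array::from(vec![10, 20, 30, 40, 50]));
    // Row indices per group, as the accumulators map collects them.
    let groups = vec![vec![0, 2], vec![1, 3, 4]];
    for (i, values) in gather_per_group(&column, &groups)?.iter().enumerate() {
        println!("group {} has {} rows", i, values.len());
    }
    Ok(())
}
```

The single `take` pays one larger allocation per column, while the per-group `slice` calls only create offset views into it, which is where the win on many small groups comes from.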

File tree

1 file changed (+79, -48 lines)


rust/datafusion/src/physical_plan/hash_aggregate.rs

Lines changed: 79 additions & 48 deletions
```diff
@@ -21,6 +21,7 @@ use std::any::Any;
 use std::sync::Arc;
 use std::task::{Context, Poll};
 
+use ahash::RandomState;
 use futures::{
     stream::{Stream, StreamExt},
     Future,
@@ -30,11 +31,10 @@ use crate::error::{DataFusionError, Result};
 use crate::physical_plan::{Accumulator, AggregateExpr};
 use crate::physical_plan::{Distribution, ExecutionPlan, Partitioning, PhysicalExpr};
 
-use arrow::error::{ArrowError, Result as ArrowResult};
-use arrow::record_batch::RecordBatch;
+use arrow::array::BooleanArray;
 use arrow::{
-    array::BooleanArray,
-    datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit},
+    array::{Array, UInt32Builder},
+    error::{ArrowError, Result as ArrowResult},
 };
 use arrow::{
     array::{
@@ -43,19 +43,22 @@ use arrow::{
     },
     compute,
 };
-use pin_project_lite::pin_project;
-
-use super::{
-    expressions::Column, group_scalar::GroupByScalar, RecordBatchStream,
-    SendableRecordBatchStream,
+use arrow::{
+    datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit},
+    record_batch::RecordBatch,
 };
-use ahash::RandomState;
 use hashbrown::HashMap;
 use ordered_float::OrderedFloat;
+use pin_project_lite::pin_project;
 
 use arrow::array::{TimestampMicrosecondArray, TimestampNanosecondArray};
 use async_trait::async_trait;
 
+use super::{
+    expressions::Column, group_scalar::GroupByScalar, RecordBatchStream,
+    SendableRecordBatchStream,
+};
+
 /// Hash aggregate modes
 #[derive(Debug, Copy, Clone)]
 pub enum AggregateMode {
@@ -322,48 +325,76 @@ fn group_aggregate_batch(
         });
     }
 
+    // Collect all indices + offsets based on keys in this vec
+    let mut batch_indices: UInt32Builder = UInt32Builder::new(0);
+    let mut offsets = vec![0];
+    let mut offset_so_far = 0;
+    for key in batch_keys.iter() {
+        let (_, _, indices) = accumulators.get_mut(key).unwrap();
+        batch_indices.append_slice(&indices)?;
+        offset_so_far += indices.len();
+        offsets.push(offset_so_far);
+    }
+    let batch_indices = batch_indices.finish();
+
+    // `Take` all values based on indices into Arrays
+    let values: Vec<Vec<Arc<dyn Array>>> = aggr_input_values
+        .iter()
+        .map(|array| {
+            array
+                .iter()
+                .map(|array| {
+                    compute::take(
+                        array.as_ref(),
+                        &batch_indices,
+                        None, // None: no index check
+                    )
+                    .unwrap()
+                })
+                .collect()
+            // 2.3
+        })
+        .collect();
+
     // 2.1 for each key in this batch
     // 2.2 for each aggregation
-    // 2.3 `take` from each of its arrays the keys' values
+    // 2.3 `slice` from each of its arrays the keys' values
    // 2.4 update / merge the accumulator with the values
     // 2.5 clear indices
-    batch_keys.iter_mut().try_for_each(|key| {
-        let (_, accumulator_set, indices) = accumulators.get_mut(key).unwrap();
-        let primitive_indices = UInt32Array::from(indices.clone());
-        // 2.2
-        accumulator_set
-            .iter_mut()
-            .zip(&aggr_input_values)
-            .map(|(accumulator, aggr_array)| {
-                (
-                    accumulator,
-                    aggr_array
-                        .iter()
-                        .map(|array| {
-                            // 2.3
-                            compute::take(
-                                array.as_ref(),
-                                &primitive_indices,
-                                None, // None: no index check
-                            )
-                            .unwrap()
-                        })
-                        .collect::<Vec<ArrayRef>>(),
-                )
-            })
-            .try_for_each(|(accumulator, values)| match mode {
-                AggregateMode::Partial => accumulator.update_batch(&values),
-                AggregateMode::Final => {
-                    // note: the aggregation here is over states, not values, thus the merge
-                    accumulator.merge_batch(&values)
-                }
-            })
-            // 2.5
-            .and({
-                indices.clear();
-                Ok(())
-            })
-    })?;
+    batch_keys
+        .iter_mut()
+        .zip(offsets.windows(2))
+        .try_for_each(|(key, offsets)| {
+            let (_, accumulator_set, indices) = accumulators.get_mut(key).unwrap();
+            // 2.2
+            accumulator_set
+                .iter_mut()
+                .zip(values.iter())
+                .map(|(accumulator, aggr_array)| {
+                    (
+                        accumulator,
+                        aggr_array
+                            .iter()
+                            .map(|array| {
+                                // 2.3
+                                array.slice(offsets[0], offsets[1] - offsets[0])
+                            })
+                            .collect(),
+                    )
+                })
+                .try_for_each(|(accumulator, values)| match mode {
+                    AggregateMode::Partial => accumulator.update_batch(&values),
+                    AggregateMode::Final => {
+                        // note: the aggregation here is over states, not values, thus the merge
+                        accumulator.merge_batch(&values)
+                    }
+                })
+                // 2.5
+                .and({
+                    indices.clear();
+                    Ok(())
+                })
+        })?;
     Ok(accumulators)
 }
```
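A note on step 2.3 above: `offsets.windows(2)` yields one `[start, end]` pair per key, so `array.slice(offsets[0], offsets[1] - offsets[0])` is a zero-copy view of exactly that key's contiguous run in the taken arrays. A tiny illustration with made-up group sizes:

```rust
fn main() {
    // Offsets as built above for hypothetical group sizes 2, 3, and 1.
    let offsets = vec![0usize, 2, 5, 6];
    for w in offsets.windows(2) {
        // Each key's values span `w[1] - w[0]` rows starting at `w[0]`.
        println!("start = {}, len = {}", w[0], w[1] - w[0]);
    }
}
```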
