Skip to content

Commit 799ce36

Browse files
[SPARK-47836][SQL] Use doubles sketch replace the GK algorithm for approximate quantile computation, significantly improving merge performance
1 parent 0e10341 commit 799ce36

File tree

3 files changed

+34
-91
lines changed

3 files changed

+34
-91
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala

Lines changed: 27 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,8 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions.aggregate
1919

20-
import java.nio.ByteBuffer
21-
22-
import com.google.common.primitives.{Doubles, Ints, Longs}
20+
import org.apache.datasketches.memory.Memory
21+
import org.apache.datasketches.quantiles.{DoublesSketch, DoublesUnion, UpdateDoublesSketch}
2322

2423
import org.apache.spark.SparkException
2524
import org.apache.spark.sql.catalyst.InternalRow
@@ -31,10 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile
3130
import org.apache.spark.sql.catalyst.trees.TernaryLike
3231
import org.apache.spark.sql.catalyst.types.PhysicalNumericType
3332
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
34-
import org.apache.spark.sql.catalyst.util.QuantileSummaries
35-
import org.apache.spark.sql.catalyst.util.QuantileSummaries.{defaultCompressThreshold, Stats}
3633
import org.apache.spark.sql.types._
37-
import org.apache.spark.util.ArrayImplicits._
3834

3935
/**
4036
* The ApproximatePercentile function returns the approximate percentile(s) of a column at the given
@@ -267,35 +263,41 @@ object ApproximatePercentile {
267263
// The default relative error can be deduced by defaultError = 1.0 / DEFAULT_PERCENTILE_ACCURACY
268264
val DEFAULT_PERCENTILE_ACCURACY: Int = 10000
269265

266+
//noinspection ScalaStyle
267+
def nextPowOf2(relativeError: Double): Int = {
268+
val baseK = DoublesSketch.getKFromEpsilon(relativeError, true)
269+
if (baseK == 1 || (baseK & (baseK - 1)) == 0) {
270+
baseK
271+
} else {
272+
Integer.highestOneBit(baseK) * 2
273+
}
274+
}
275+
270276
/**
271277
* PercentileDigest is a probabilistic data structure used for approximating percentiles
272-
* with limited memory. PercentileDigest is backed by [[QuantileSummaries]].
278+
* with limited memory. PercentileDigest is backed by [[DoublesSketch]].
273279
*
274-
* @param summaries underlying probabilistic data structure [[QuantileSummaries]].
280+
* @param sketch underlying probabilistic data structure [[DoublesSketch]].
275281
*/
276-
class PercentileDigest(private var summaries: QuantileSummaries) {
282+
class PercentileDigest(private var sketch: UpdateDoublesSketch) {
277283

278284
def this(relativeError: Double) = {
279-
this(new QuantileSummaries(defaultCompressThreshold, relativeError, compressed = true))
285+
this(DoublesSketch.builder().setK(ApproximatePercentile.nextPowOf2(relativeError)).build())
280286
}
281287

282-
private[sql] def isCompressed: Boolean = summaries.compressed
283-
284-
/** Returns compressed object of [[QuantileSummaries]] */
285-
def quantileSummaries: QuantileSummaries = {
286-
if (!isCompressed) compress()
287-
summaries
288-
}
288+
def sketchInfo: UpdateDoublesSketch = sketch
289289

290290
/** Insert an observation value into the PercentileDigest data structure. */
291291
def add(value: Double): Unit = {
292-
summaries = summaries.insert(value)
292+
sketch.update(value)
293293
}
294294

295295
/** In-place merges in another PercentileDigest. */
296296
def merge(other: PercentileDigest): Unit = {
297-
if (!isCompressed) compress()
298-
summaries = summaries.merge(other.quantileSummaries)
297+
val doublesUnion = DoublesUnion.builder().setMaxK(sketch.getK).build()
298+
doublesUnion.union(sketch)
299+
doublesUnion.union(other.sketch)
300+
sketch = doublesUnion.getResult
299301
}
300302

301303
/**
@@ -309,16 +311,7 @@ object ApproximatePercentile {
309311
* }}}
310312
*/
311313
def getPercentiles(percentages: Array[Double]): Seq[Double] = {
312-
if (!isCompressed) compress()
313-
if (summaries.count == 0 || percentages.length == 0) {
314-
Array.emptyDoubleArray.toImmutableArraySeq
315-
} else {
316-
summaries.query(percentages.toImmutableArraySeq).get
317-
}
318-
}
319-
320-
private final def compress(): Unit = {
321-
summaries = summaries.compress()
314+
sketch.getQuantiles(percentages).toSeq
322315
}
323316
}
324317

@@ -329,52 +322,14 @@ object ApproximatePercentile {
329322
*/
330323
class PercentileDigestSerializer {
331324

332-
private final def length(summaries: QuantileSummaries): Int = {
333-
// summaries.compressThreshold, summary.relativeError, summary.count
334-
Ints.BYTES + Doubles.BYTES + Longs.BYTES +
335-
// length of summary.sampled
336-
Ints.BYTES +
337-
// summary.sampled, Array[Stat(value: Double, g: Long, delta: Long)]
338-
summaries.sampled.length * (Doubles.BYTES + Longs.BYTES + Longs.BYTES)
339-
}
340-
341325
final def serialize(obj: PercentileDigest): Array[Byte] = {
342-
val summary = obj.quantileSummaries
343-
val buffer = ByteBuffer.wrap(new Array(length(summary)))
344-
buffer.putInt(summary.compressThreshold)
345-
buffer.putDouble(summary.relativeError)
346-
buffer.putLong(summary.count)
347-
buffer.putInt(summary.sampled.length)
348-
349-
var i = 0
350-
while (i < summary.sampled.length) {
351-
val stat = summary.sampled(i)
352-
buffer.putDouble(stat.value)
353-
buffer.putLong(stat.g)
354-
buffer.putLong(stat.delta)
355-
i += 1
356-
}
357-
buffer.array()
326+
val sketch = obj.sketchInfo
327+
sketch.toByteArray(false)
358328
}
359329

360330
final def deserialize(bytes: Array[Byte]): PercentileDigest = {
361-
val buffer = ByteBuffer.wrap(bytes)
362-
val compressThreshold = buffer.getInt()
363-
val relativeError = buffer.getDouble()
364-
val count = buffer.getLong()
365-
val sampledLength = buffer.getInt()
366-
val sampled = new Array[Stats](sampledLength)
367-
368-
var i = 0
369-
while (i < sampledLength) {
370-
val value = buffer.getDouble()
371-
val g = buffer.getLong()
372-
val delta = buffer.getLong()
373-
sampled(i) = Stats(value, g, delta)
374-
i += 1
375-
}
376-
val summary = new QuantileSummaries(compressThreshold, relativeError, sampled, count, true)
377-
new PercentileDigest(summary)
331+
val sketch = DoublesSketch.heapify(Memory.wrap(bytes))
332+
new PercentileDigest(sketch.asInstanceOf[UpdateDoublesSketch])
378333
}
379334
}
380335

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -426,12 +426,13 @@ class ApproximatePercentileSuite extends SparkFunSuite {
426426
}
427427

428428
private def compareEquals(left: PercentileDigest, right: PercentileDigest): Boolean = {
429-
val leftSummary = left.quantileSummaries
430-
val rightSummary = right.quantileSummaries
431-
leftSummary.compressThreshold == rightSummary.compressThreshold &&
432-
leftSummary.relativeError == rightSummary.relativeError &&
433-
leftSummary.count == rightSummary.count &&
434-
leftSummary.sampled.sameElements(rightSummary.sampled)
429+
val leftSketch = left.sketchInfo
430+
val rightSketch = right.sketchInfo
431+
leftSketch.getK == rightSketch.getK &&
432+
leftSketch.getMaxItem == rightSketch.getMaxItem &&
433+
leftSketch.getMinItem == rightSketch.getMinItem &&
434+
leftSketch.getN == rightSketch.getN
435+
true
435436
}
436437

437438
private def assertEqual[T](left: T, right: T): Unit = {

sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ package org.apache.spark.sql
2020
import java.sql.{Date, Timestamp}
2121
import java.time.{Duration, LocalDateTime, Period}
2222

23-
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile
2423
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY
2524
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest
2625
import org.apache.spark.sql.catalyst.util.DateTimeUtils
@@ -291,18 +290,6 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSparkSession
291290
}
292291
}
293292

294-
test("SPARK-24013: unneeded compress can cause performance issues with sorted input") {
295-
val buffer = new PercentileDigest(1.0D / ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY)
296-
var compressCounts = 0
297-
(1 to 10000000).foreach { i =>
298-
buffer.add(i)
299-
if (buffer.isCompressed) compressCounts += 1
300-
}
301-
assert(compressCounts > 0)
302-
buffer.quantileSummaries
303-
assert(buffer.isCompressed)
304-
}
305-
306293
test("SPARK-32908: maximum target error in percentile_approx") {
307294
withTempView(table) {
308295
spark.read

0 commit comments

Comments
 (0)