1717
1818package  org .apache .spark .sql .catalyst .expressions .aggregate 
1919
20- import  java .nio .ByteBuffer 
21- 
22- import  com .google .common .primitives .{Doubles , Ints , Longs }
20+ import  org .apache .datasketches .memory .Memory 
21+ import  org .apache .datasketches .quantiles .{DoublesSketch , DoublesUnion , UpdateDoublesSketch }
2322
2423import  org .apache .spark .SparkException 
2524import  org .apache .spark .sql .catalyst .InternalRow 
@@ -31,10 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile
3130import  org .apache .spark .sql .catalyst .trees .TernaryLike 
3231import  org .apache .spark .sql .catalyst .types .PhysicalNumericType 
3332import  org .apache .spark .sql .catalyst .util .{ArrayData , GenericArrayData }
34- import  org .apache .spark .sql .catalyst .util .QuantileSummaries 
35- import  org .apache .spark .sql .catalyst .util .QuantileSummaries .{defaultCompressThreshold , Stats }
3633import  org .apache .spark .sql .types ._ 
37- import  org .apache .spark .util .ArrayImplicits ._ 
3834
3935/** 
4036 * The ApproximatePercentile function returns the approximate percentile(s) of a column at the given 
@@ -267,35 +263,41 @@ object ApproximatePercentile {
267263  //  The default relative error can be deduced by defaultError = 1.0 / DEFAULT_PERCENTILE_ACCURACY
268264  val  DEFAULT_PERCENTILE_ACCURACY :  Int  =  10000 
269265
266+   // noinspection ScalaStyle
267+   def  nextPowOf2 (relativeError : Double ):  Int  =  {
268+     val  baseK  =  DoublesSketch .getKFromEpsilon(relativeError, true )
269+     if  (baseK ==  1  ||  (baseK &  (baseK -  1 )) ==  0 ) {
270+       baseK
271+     } else  {
272+       Integer .highestOneBit(baseK) *  2 
273+     }
274+   }
275+ 
270276  /**  
271277   * PercentileDigest is a probabilistic data structure used for approximating percentiles 
272-    * with limited memory. PercentileDigest is backed by [[QuantileSummaries ]]. 
278+    * with limited memory. PercentileDigest is backed by [[DoublesSketch ]]. 
273279   * 
274-    * @param  summaries  underlying probabilistic data structure [[QuantileSummaries ]]. 
280+    * @param  sketch  underlying probabilistic data structure [[DoublesSketch ]]. 
275281   */  
276-   class  PercentileDigest (private  var  summaries :   QuantileSummaries ) {
282+   class  PercentileDigest (private  var  sketch :   UpdateDoublesSketch ) {
277283
278284    def  this (relativeError : Double ) =  {
279-       this (new   QuantileSummaries (defaultCompressThreshold,  relativeError, compressed  =   true ))
285+       this (DoublesSketch .builder().setK( ApproximatePercentile .nextPowOf2( relativeError)).build( ))
280286    }
281287
282-     private [sql] def  isCompressed :  Boolean  =  summaries.compressed
283- 
284-     /**  Returns compressed object of [[QuantileSummaries ]] */  
285-     def  quantileSummaries :  QuantileSummaries  =  {
286-       if  (! isCompressed) compress()
287-       summaries
288-     }
288+     def  sketchInfo :  UpdateDoublesSketch  =  sketch
289289
290290    /**  Insert an observation value into the PercentileDigest data structure. */  
291291    def  add (value : Double ):  Unit  =  {
292-       summaries  =  summaries.insert (value)
292+       sketch.update (value)
293293    }
294294
295295    /**  In-place merges in another PercentileDigest. */  
296296    def  merge (other : PercentileDigest ):  Unit  =  {
297-       if  (! isCompressed) compress()
298-       summaries =  summaries.merge(other.quantileSummaries)
297+       val  doublesUnion  =  DoublesUnion .builder().setMaxK(sketch.getK).build()
298+       doublesUnion.union(sketch)
299+       doublesUnion.union(other.sketch)
300+       sketch =  doublesUnion.getResult
299301    }
300302
301303    /**  
@@ -309,16 +311,7 @@ object ApproximatePercentile {
309311     * }}} 
310312     */  
311313    def  getPercentiles (percentages : Array [Double ]):  Seq [Double ] =  {
312-       if  (! isCompressed) compress()
313-       if  (summaries.count ==  0  ||  percentages.length ==  0 ) {
314-         Array .emptyDoubleArray.toImmutableArraySeq
315-       } else  {
316-         summaries.query(percentages.toImmutableArraySeq).get
317-       }
318-     }
319- 
320-     private  final  def  compress ():  Unit  =  {
321-       summaries =  summaries.compress()
314+       sketch.getQuantiles(percentages).toSeq
322315    }
323316  }
324317
@@ -329,52 +322,14 @@ object ApproximatePercentile {
329322   */  
330323  class  PercentileDigestSerializer  {
331324
332-     private  final  def  length (summaries : QuantileSummaries ):  Int  =  {
333-       //  summaries.compressThreshold, summary.relativeError, summary.count
334-       Ints .BYTES  +  Doubles .BYTES  +  Longs .BYTES  + 
335-       //  length of summary.sampled
336-       Ints .BYTES  + 
337-       //  summary.sampled, Array[Stat(value: Double, g: Long, delta: Long)]
338-       summaries.sampled.length *  (Doubles .BYTES  +  Longs .BYTES  +  Longs .BYTES )
339-     }
340- 
341325    final  def  serialize (obj : PercentileDigest ):  Array [Byte ] =  {
342-       val  summary  =  obj.quantileSummaries
343-       val  buffer  =  ByteBuffer .wrap(new  Array (length(summary)))
344-       buffer.putInt(summary.compressThreshold)
345-       buffer.putDouble(summary.relativeError)
346-       buffer.putLong(summary.count)
347-       buffer.putInt(summary.sampled.length)
348- 
349-       var  i  =  0 
350-       while  (i <  summary.sampled.length) {
351-         val  stat  =  summary.sampled(i)
352-         buffer.putDouble(stat.value)
353-         buffer.putLong(stat.g)
354-         buffer.putLong(stat.delta)
355-         i +=  1 
356-       }
357-       buffer.array()
326+       val  sketch  =  obj.sketchInfo
327+       sketch.toByteArray(false )
358328    }
359329
360330    final  def  deserialize (bytes : Array [Byte ]):  PercentileDigest  =  {
361-       val  buffer  =  ByteBuffer .wrap(bytes)
362-       val  compressThreshold  =  buffer.getInt()
363-       val  relativeError  =  buffer.getDouble()
364-       val  count  =  buffer.getLong()
365-       val  sampledLength  =  buffer.getInt()
366-       val  sampled  =  new  Array [Stats ](sampledLength)
367- 
368-       var  i  =  0 
369-       while  (i <  sampledLength) {
370-         val  value  =  buffer.getDouble()
371-         val  g  =  buffer.getLong()
372-         val  delta  =  buffer.getLong()
373-         sampled(i) =  Stats (value, g, delta)
374-         i +=  1 
375-       }
376-       val  summary  =  new  QuantileSummaries (compressThreshold, relativeError, sampled, count, true )
377-       new  PercentileDigest (summary)
331+       val  sketch  =  DoublesSketch .heapify(Memory .wrap(bytes))
332+       new  PercentileDigest (sketch.asInstanceOf [UpdateDoublesSketch ])
378333    }
379334  }
380335
0 commit comments