@@ -5,7 +5,7 @@ import com.datastax.spark.connector.util.Logging
5
5
import org .apache .spark .sql .{SparkSession , Strategy }
6
6
import org .apache .spark .sql .cassandra .{AlwaysOff , AlwaysOn , Automatic , CassandraSourceRelation }
7
7
import org .apache .spark .sql .cassandra .CassandraSourceRelation ._
8
- import org .apache .spark .sql .catalyst .expressions .{Alias , AttributeReference , ExprId , Expression , NamedExpression }
8
+ import org .apache .spark .sql .catalyst .expressions .{Alias , Attribute , AttributeReference , ExprId , Expression , NamedExpression }
9
9
import org .apache .spark .sql .catalyst .planning .{ExtractEquiJoinKeys , PhysicalOperation }
10
10
import org .apache .spark .sql .catalyst .plans .logical ._
11
11
import org .apache .spark .sql .catalyst .plans ._
@@ -59,7 +59,7 @@ case class CassandraDirectJoinStrategy(spark: SparkSession) extends Strategy wit
59
59
cassandraScanExec
60
60
)
61
61
62
- val newPlan = reorderPlan(dataSourceOptimizedPlan, directJoin) :: Nil
62
+ val newPlan = reorderPlan(dataSourceOptimizedPlan, directJoin, plan.output ) :: Nil
63
63
val newOutput = (newPlan.head.outputSet, newPlan.head.output.map(_.name))
64
64
val oldOutput = (plan.outputSet, plan.output.map(_.name))
65
65
val noMissingOutput = oldOutput._1.subsetOf(newPlan.head.outputSet)
@@ -232,7 +232,10 @@ object CassandraDirectJoinStrategy extends Logging {
232
232
*
233
233
* This should only be called on optimized Physical Plans
234
234
*/
235
- def reorderPlan (plan : SparkPlan , directJoin : CassandraDirectJoinExec ): SparkPlan = {
235
+ def reorderPlan (
236
+ plan : SparkPlan ,
237
+ directJoin : CassandraDirectJoinExec ,
238
+ originalOutput : Seq [Attribute ]): SparkPlan = {
236
239
val reordered = plan match {
237
240
// This may be the only node in the Plan
238
241
case BatchScanExec (_, _ : CassandraScan , _) => directJoin
@@ -252,19 +255,25 @@ object CassandraDirectJoinStrategy extends Logging {
252
255
*/
253
256
reordered.transform {
254
257
case ProjectExec (projectList, child) =>
258
+ val attrMap = directJoin.output.map {
259
+ case attr => attr.exprId -> attr
260
+ }.toMap
261
+
255
262
val aliases = projectList.collect {
256
- case a @ Alias (child : AttributeReference , _) => (child.toAttribute.exprId, a)
263
+ case a @ Alias (child, _) =>
264
+ val newAliasChild = child.transform {
265
+ case attr : Attribute => attrMap.getOrElse(attr.exprId, attr)
266
+ }
267
+ (a.exprId, a.withNewChildren(newAliasChild :: Nil ).asInstanceOf [Alias ])
257
268
}.toMap
258
269
259
- val aliasedOutput = directJoin.output.map {
260
- case attr if aliases.contains(attr.exprId) =>
261
- val oldAlias = aliases(attr.exprId)
262
- oldAlias.copy(child = attr)(oldAlias.exprId, oldAlias.qualifier,
263
- oldAlias.explicitMetadata, oldAlias.nonInheritableMetadataKeys)
270
+ // The original output of Join
271
+ val reorderedOutput = originalOutput.map {
272
+ case attr if aliases.contains(attr.exprId) => aliases(attr.exprId)
264
273
case other => other
265
274
}
266
275
267
- ProjectExec (aliasedOutput , child)
276
+ ProjectExec (reorderedOutput , child)
268
277
}
269
278
}
270
279
@@ -310,13 +319,21 @@ object CassandraDirectJoinStrategy extends Logging {
310
319
case _ => false
311
320
}
312
321
322
/**
  * Pairs the name of the attribute behind a projected expression with an expression id.
  *
  *  - A bare AttributeReference yields its own name and its own id.
  *  - An Alias placed directly on an AttributeReference yields the attribute's name
  *    and the alias's id.
  *  - An Alias over any other expression must reference exactly one AttributeReference
  *    (asserted); it yields that attribute's name and that attribute's id.
  *
  * Any other kind of NamedExpression is not matched and fails with a MatchError.
  */
def getAlias(expr: NamedExpression): (String, ExprId) = expr match {
  case plain: AttributeReference =>
    (plain.name, plain.exprId)
  case alias @ Alias(direct: AttributeReference, _) =>
    (direct.name, alias.exprId)
  case Alias(wrapped, _) =>
    // Gather every attribute mentioned anywhere inside the aliased expression.
    val referenced = wrapped.collect { case ref: AttributeReference => ref }
    assert(referenced.length == 1)
    val source = referenced(0)
    (source.name, source.exprId)
}
332
+
313
333
/**
314
334
* Map Source Cassandra Column Names to ExpressionIds referring to them
315
335
*/
316
- def aliasMap (aliases : Seq [NamedExpression ]): Map [String , ExprId ] = aliases.map {
317
- case a @ Alias (child : AttributeReference , _) => child.name -> a.exprId
318
- case attributeReference : AttributeReference => attributeReference.name -> attributeReference.exprId
319
- }.toMap
336
// Folds each projected expression into a column-name -> ExprId lookup via getAlias;
// a name that occurs more than once keeps the id from its last occurrence.
def aliasMap(aliases: Seq[NamedExpression]): Map[String, ExprId] =
  aliases.foldLeft(Map.empty[String, ExprId]) { (byName, expr) => byName + getAlias(expr) }
320
337
321
338
/**
322
339
* Checks whether a logical plan contains only Filters, Aliases
0 commit comments