Skip to content

Commit f125820

Browse files
authored
SPARKC-695: Fix projection collapse on CassandraDirectJoinStrategy (#1353)
1 parent 0bd81a1 commit f125820

File tree

2 files changed

+58
-14
lines changed

2 files changed

+58
-14
lines changed

connector/src/it/scala/org/apache/spark/sql/cassandra/execution/CassandraDirectJoinSpec.scala

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,17 @@ class CassandraDirectJoinSpec extends SparkCassandraITFlatSpecBase with DefaultC
9696
| city: 'New Orleans',
9797
| residents:{('sundance', 'dog'), ('cara', 'dog')}
9898
|})""".stripMargin)
99+
session.execute(s"CREATE TYPE $ks.user (address frozen <address>) ")
100+
session.execute(s"CREATE TABLE $ks.members (id text, user frozen <user>, PRIMARY KEY (id))")
101+
session.execute(
102+
s"""INSERT INTO $ks.members (id, user) VALUES ('test1',
103+
|{
104+
| address: {
105+
| street: 'Laurel',
106+
| city: 'New Orleans',
107+
| residents:{('sundance', 'dog'), ('cara', 'dog')}
108+
| }
109+
|})""".stripMargin)
99110
},
100111
Future {
101112
info("Making table with all PV4 Datatypes")
@@ -622,6 +633,22 @@ class CassandraDirectJoinSpec extends SparkCassandraITFlatSpecBase with DefaultC
622633
left.join(right, left("id") === right("id"))
623634
}
624635

636+
it should "work with field extractor after join" in compareDirectOnDirectOff{ spark =>
637+
val left = spark.createDataset(Seq(IdRow("test")))
638+
val right = spark.read.cassandraFormat("location", ks).load()
639+
left.join(right, left("id") === right("id"))
640+
.select($"address.*")
641+
.select($"street", $"city")
642+
}
643+
644+
it should "work with deeply nested field extractor after join" in compareDirectOnDirectOff{ spark =>
645+
val left = spark.createDataset(Seq(IdRow("test1")))
646+
val right = spark.read.cassandraFormat("members", ks).load()
647+
left.join(right, left("id") === right("id"))
648+
.select($"user.address.*")
649+
.select($"street", $"city")
650+
}
651+
625652
it should "work on a timestamp PK join" in compareDirectOnDirectOff { spark =>
626653
val left = spark.createDataset(
627654
(1 to 100).map(value => TimestampRow(new Timestamp(value.toLong)))

connector/src/main/scala/org/apache/spark/sql/cassandra/execution/CassandraDirectJoinStrategy.scala

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import com.datastax.spark.connector.util.Logging
55
import org.apache.spark.sql.{SparkSession, Strategy}
66
import org.apache.spark.sql.cassandra.{AlwaysOff, AlwaysOn, Automatic, CassandraSourceRelation}
77
import org.apache.spark.sql.cassandra.CassandraSourceRelation._
8-
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, ExprId, Expression, NamedExpression}
8+
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, ExprId, Expression, NamedExpression}
99
import org.apache.spark.sql.catalyst.planning.{ExtractEquiJoinKeys, PhysicalOperation}
1010
import org.apache.spark.sql.catalyst.plans.logical._
1111
import org.apache.spark.sql.catalyst.plans._
@@ -59,7 +59,7 @@ case class CassandraDirectJoinStrategy(spark: SparkSession) extends Strategy wit
5959
cassandraScanExec
6060
)
6161

62-
val newPlan = reorderPlan(dataSourceOptimizedPlan, directJoin) :: Nil
62+
val newPlan = reorderPlan(dataSourceOptimizedPlan, directJoin, plan.output) :: Nil
6363
val newOutput = (newPlan.head.outputSet, newPlan.head.output.map(_.name))
6464
val oldOutput = (plan.outputSet, plan.output.map(_.name))
6565
val noMissingOutput = oldOutput._1.subsetOf(newPlan.head.outputSet)
@@ -232,7 +232,10 @@ object CassandraDirectJoinStrategy extends Logging {
232232
*
233233
* This should only be called on optimized Physical Plans
234234
*/
235-
def reorderPlan(plan: SparkPlan, directJoin: CassandraDirectJoinExec): SparkPlan = {
235+
def reorderPlan(
236+
plan: SparkPlan,
237+
directJoin: CassandraDirectJoinExec,
238+
originalOutput: Seq[Attribute]): SparkPlan = {
236239
val reordered = plan match {
237240
//This may be the only node in the Plan
238241
case BatchScanExec(_, _: CassandraScan, _) => directJoin
@@ -252,19 +255,25 @@ object CassandraDirectJoinStrategy extends Logging {
252255
*/
253256
reordered.transform {
254257
case ProjectExec(projectList, child) =>
258+
val attrMap = directJoin.output.map {
259+
case attr => attr.exprId -> attr
260+
}.toMap
261+
255262
val aliases = projectList.collect {
256-
case a @ Alias(child: AttributeReference, _) => (child.toAttribute.exprId, a)
263+
case a @ Alias(child, _) =>
264+
val newAliasChild = child.transform {
265+
case attr: Attribute => attrMap.getOrElse(attr.exprId, attr)
266+
}
267+
(a.exprId, a.withNewChildren(newAliasChild :: Nil).asInstanceOf[Alias])
257268
}.toMap
258269

259-
val aliasedOutput = directJoin.output.map {
260-
case attr if aliases.contains(attr.exprId) =>
261-
val oldAlias = aliases(attr.exprId)
262-
oldAlias.copy(child = attr)(oldAlias.exprId, oldAlias.qualifier,
263-
oldAlias.explicitMetadata, oldAlias.nonInheritableMetadataKeys)
270+
// The original output of Join
271+
val reorderedOutput = originalOutput.map {
272+
case attr if aliases.contains(attr.exprId) => aliases(attr.exprId)
264273
case other => other
265274
}
266275

267-
ProjectExec(aliasedOutput, child)
276+
ProjectExec(reorderedOutput, child)
268277
}
269278
}
270279

@@ -310,13 +319,21 @@ object CassandraDirectJoinStrategy extends Logging {
310319
case _ => false
311320
}
312321

322+
def getAlias(expr: NamedExpression): (String, ExprId) = expr match {
323+
case a @ Alias(child: AttributeReference, _) => child.name -> a.exprId
324+
case a @ Alias(child, _) =>
325+
val attrs = child.collect {
326+
case attr: AttributeReference => attr
327+
}
328+
assert(attrs.length == 1)
329+
attrs(0).name -> attrs(0).exprId
330+
case attributeReference: AttributeReference => attributeReference.name -> attributeReference.exprId
331+
}
332+
313333
/**
314334
* Map Source Cassandra Column Names to ExpressionIds referring to them
315335
*/
316-
def aliasMap(aliases: Seq[NamedExpression]): Map[String, ExprId] = aliases.map {
317-
case a @ Alias(child: AttributeReference, _) => child.name -> a.exprId
318-
case attributeReference: AttributeReference => attributeReference.name -> attributeReference.exprId
319-
}.toMap
336+
def aliasMap(aliases: Seq[NamedExpression]): Map[String, ExprId] = aliases.map(getAlias).toMap
320337

321338
/**
322339
* Checks whether a logical plan contains only Filters, Aliases

Comments: 0 commit comments