Commit 1621de2

add
1 parent 5660dba commit 1621de2

5 files changed: +169 -6 lines

common/utils/src/main/resources/error/error-conditions.json

Lines changed: 7 additions & 0 deletions
@@ -4086,6 +4086,13 @@
     ],
     "sqlState" : "42K0L"
   },
+  "LATERAL_JOIN_WITH_ARROW_UDTF_UNSUPPORTED" : {
+    "message" : [
+      "LATERAL JOIN with Arrow-optimized user-defined table functions (UDTFs) is not supported. Arrow UDTFs cannot be used on the right-hand side of a lateral join.",
+      "Please use a regular UDTF instead, or restructure your query to avoid the lateral join."
+    ],
+    "sqlState" : "0A000"
+  },
   "LOAD_DATA_PATH_NOT_EXISTS" : {
     "message" : [
       "LOAD DATA input path does not exist: <path>."

python/pyspark/sql/tests/arrow/test_arrow_udtf.py

Lines changed: 130 additions & 0 deletions
@@ -26,6 +26,7 @@
 
 if have_pyarrow:
     import pyarrow as pa
+    import pyarrow.compute as pc
 
 
 @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
@@ -460,6 +461,135 @@ def eval(
         )
         assertDataFrameEqual(sql_result_df, expected_df)
 
+    def test_arrow_udtf_lateral_join_disallowed(self):
+        @arrow_udtf(returnType="x int, result int")
+        class SimpleArrowUDTF:
+            def eval(self, input_val: "pa.Array") -> Iterator["pa.Table"]:
+                val = input_val[0].as_py()
+                result_table = pa.table(
+                    {
+                        "x": pa.array([val], type=pa.int32()),
+                        "result": pa.array([val * 2], type=pa.int32()),
+                    }
+                )
+                yield result_table
+
+        self.spark.udtf.register("simple_arrow_udtf", SimpleArrowUDTF)
+
+        test_df = self.spark.createDataFrame([(1,), (2,), (3,)], "id int")
+        test_df.createOrReplaceTempView("test_table")
+
+        with self.assertRaisesRegex(Exception, "LATERAL_JOIN_WITH_ARROW_UDTF_UNSUPPORTED"):
+            self.spark.sql(
+                """
+                SELECT t.id, f.x, f.result
+                FROM test_table t, LATERAL simple_arrow_udtf(t.id) f
+                """
+            )
+
+    def test_arrow_udtf_lateral_join_with_table_argument_disallowed(self):
+        @arrow_udtf(returnType="filtered_id bigint")
+        class MixedArgsUDTF:
+            def eval(self, input_table: "pa.Table") -> Iterator["pa.Table"]:
+                filtered_data = input_table.filter(pc.greater(input_table["id"], 5))
+                result_table = pa.table({"filtered_id": filtered_data["id"]})
+                yield result_table
+
+        self.spark.udtf.register("mixed_args_udtf", MixedArgsUDTF)
+
+        test_df1 = self.spark.createDataFrame([(1,), (2,), (3,)], "id int")
+        test_df1.createOrReplaceTempView("test_table1")
+
+        test_df2 = self.spark.createDataFrame([(6,), (7,), (8,)], "id bigint")
+        test_df2.createOrReplaceTempView("test_table2")
+
+        # Table arguments create nested lateral joins where our CheckAnalysis rule doesn't
+        # trigger because the Arrow UDTF is in the inner lateral join, not the outer one
+        # our rule checks. So Spark's general lateral join validation catches this first
+        # with NON_DETERMINISTIC_LATERAL_SUBQUERIES.
+        with self.assertRaisesRegex(
+            Exception,
+            "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NON_DETERMINISTIC_LATERAL_SUBQUERIES",
+        ):
+            self.spark.sql(
+                """
+                SELECT t1.id, f.filtered_id
+                FROM test_table1 t1, LATERAL mixed_args_udtf(table(SELECT * FROM test_table2)) f
+                """
+            )
+
+    def test_arrow_udtf_with_table_argument_then_lateral_join_allowed(self):
+        @arrow_udtf(returnType="processed_id bigint")
+        class TableArgUDTF:
+            def eval(self, input_table: "pa.Table") -> Iterator["pa.Table"]:
+                processed_data = pc.add(input_table["id"], 100)
+                result_table = pa.table({"processed_id": processed_data})
+                yield result_table
+
+        self.spark.udtf.register("table_arg_udtf", TableArgUDTF)
+
+        source_df = self.spark.createDataFrame([(1,), (2,), (3,)], "id bigint")
+        source_df.createOrReplaceTempView("source_table")
+
+        join_df = self.spark.createDataFrame([("A",), ("B",), ("C",)], "label string")
+        join_df.createOrReplaceTempView("join_table")
+
+        result_df = self.spark.sql(
+            """
+            SELECT f.processed_id, j.label
+            FROM table_arg_udtf(table(SELECT * FROM source_table)) f,
+                 join_table j
+            ORDER BY f.processed_id, j.label
+            """
+        )
+
+        expected_data = [
+            (101, "A"),
+            (101, "B"),
+            (101, "C"),
+            (102, "A"),
+            (102, "B"),
+            (102, "C"),
+            (103, "A"),
+            (103, "B"),
+            (103, "C"),
+        ]
+        expected_df = self.spark.createDataFrame(expected_data, "processed_id bigint, label string")
+        assertDataFrameEqual(result_df, expected_df)
+
+    def test_arrow_udtf_table_argument_with_regular_udtf_lateral_join_allowed(self):
+        @arrow_udtf(returnType="computed_value int")
+        class ComputeUDTF:
+            def eval(self, input_table: "pa.Table") -> Iterator["pa.Table"]:
+                total = pc.sum(input_table["value"]).as_py()
+                result_table = pa.table({"computed_value": pa.array([total], type=pa.int32())})
+                yield result_table
+
+        from pyspark.sql.functions import udtf
+        from pyspark.sql.types import StructType, StructField, IntegerType
+
+        @udtf(returnType=StructType([StructField("multiplied", IntegerType())]))
+        class MultiplyUDTF:
+            def eval(self, input_val: int):
+                yield (input_val * 3,)
+
+        self.spark.udtf.register("compute_udtf", ComputeUDTF)
+        self.spark.udtf.register("multiply_udtf", MultiplyUDTF)
+
+        values_df = self.spark.createDataFrame([(10,), (20,), (30,)], "value int")
+        values_df.createOrReplaceTempView("values_table")
+
+        result_df = self.spark.sql(
+            """
+            SELECT c.computed_value, m.multiplied
+            FROM compute_udtf(table(SELECT * FROM values_table) WITH SINGLE PARTITION) c,
+                 LATERAL multiply_udtf(c.computed_value) m
+            """
+        )
+
+        expected_df = self.spark.createDataFrame([(60, 180)], "computed_value int, multiplied int")
+        assertDataFrameEqual(result_df, expected_df)
+
 
 class ArrowUDTFTests(ArrowUDTFTestsMixin, ReusedSQLTestCase):
     pass

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 11 additions & 6 deletions
@@ -2229,12 +2229,15 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
             case _ => tvf
           }
 
-          Project(
-            Seq(UnresolvedStar(Some(Seq(alias)))),
-            LateralJoin(
-              tableArgs.map(_._2).reduceLeft(Join(_, _, Inner, None, JoinHint.NONE)),
-              LateralSubquery(SubqueryAlias(alias, tvfWithTableColumnIndexes)), Inner, None)
-          )
+          val lateralJoin = LateralJoin(
+            tableArgs.map(_._2).reduceLeft(Join(_, _, Inner, None, JoinHint.NONE)),
+            LateralSubquery(SubqueryAlias(alias, tvfWithTableColumnIndexes)), Inner, None)
+
+          // Set the tag so that it can be used to differentiate a lateral join added by
+          // a TABLE argument from one added by the user.
+          lateralJoin.setTagValue(LateralJoin.BY_TABLE_ARGUMENT, ())
+
+          Project(Seq(UnresolvedStar(Some(Seq(alias)))), lateralJoin)
         }
 
       case q: LogicalPlan =>
@@ -4249,6 +4252,8 @@ object RemoveTempResolvedColumn extends Rule[LogicalPlan] {
   }
 }
 
+
+
 /**
  * Rule that's used to handle `UnresolvedHaving` nodes with resolved `condition` and `child`.
  * It's placed outside the main batch to avoid conflicts with other rules that resolve
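
The tagging above uses Catalyst's TreeNodeTag mechanism: a typed key attached to a plan node during resolution and read back during validation. A minimal standalone sketch (OneRowRelation is used here only as a convenient plan node; the real tag is defined on the LateralJoin companion object in the last file of this commit):

    import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
    import org.apache.spark.sql.catalyst.trees.TreeNodeTag

    object TreeNodeTagSketch {
      def main(args: Array[String]): Unit = {
        // A typed key; with a Unit payload, mere presence of the tag is the signal.
        val byTableArgument = TreeNodeTag[Unit]("by_table_argument")

        val plan = OneRowRelation()
        assert(plan.getTagValue(byTableArgument).isEmpty)   // untagged: user-written plan
        plan.setTagValue(byTableArgument, ())               // mark analyzer-generated plans
        assert(plan.getTagValue(byTableArgument).isDefined) // CheckAnalysis can now skip it
      }
    }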

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 12 additions & 0 deletions
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis
 import scala.collection.mutable
 
 import org.apache.spark.{SparkException, SparkThrowable}
+import org.apache.spark.api.python.PythonEvalType
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.ExtendedAnalysisException
 import org.apache.spark.sql.catalyst.analysis.ResolveWithCTE.checkIfSelfReferenceIsPlacedCorrectly
@@ -889,6 +890,17 @@ trait CheckAnalysis extends LookupCatalog with QueryErrorsBase with PlanToString
             messageParameters = Map(
               "invalidExprSqls" -> invalidExprSqls.mkString(", ")))
 
+        case j @ LateralJoin(_, right, _, _)
+            if j.getTagValue(LateralJoin.BY_TABLE_ARGUMENT).isEmpty =>
+          right.plan.foreach {
+            case Generate(pyudtf: PythonUDTF, _, _, _, _, _)
+                if pyudtf.evalType == PythonEvalType.SQL_ARROW_UDTF =>
+              j.failAnalysis(
+                errorClass = "LATERAL_JOIN_WITH_ARROW_UDTF_UNSUPPORTED",
+                messageParameters = Map.empty)
+            case _ =>
+          }
+
         case _ => // Analysis successful!
       }
     }
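
Net effect: only user-written lateral joins are rejected; the LateralJoin the analyzer synthesizes for a TABLE(...) argument carries the tag and is skipped. A sketch of the two shapes, assuming the same hypothetical session and Python-registered Arrow UDTFs as the earlier sketch:

    // Still allowed: the LateralJoin the analyzer itself builds for a TABLE(...)
    // argument carries BY_TABLE_ARGUMENT, so the new case above skips it.
    spark.sql(
      """SELECT f.processed_id
        |FROM table_arg_udtf(TABLE(SELECT * FROM source_table)) f""".stripMargin)

    // Rejected: a user-written LATERAL join whose right side is a Generate over a
    // PythonUDTF with eval type SQL_ARROW_UDTF has no tag and fails analysis.
    spark.sql(
      """SELECT t.id, f.x
        |FROM test_table t, LATERAL my_arrow_udtf(t.id) f""".stripMargin)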

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala

Lines changed: 9 additions & 0 deletions
@@ -2115,6 +2115,15 @@ case class LateralJoin(
   }
 }
 
+
+object LateralJoin {
+  /**
+   * A tag to identify if a LateralJoin was added by resolving a table argument.
+   */
+  val BY_TABLE_ARGUMENT = TreeNodeTag[Unit]("by_table_argument")
+}
+
+
 /**
  * A logical plan for as-of join.
  */
