 
 if have_pyarrow:
     import pyarrow as pa
+    import pyarrow.compute as pc
 
 
 @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
@@ -460,6 +461,135 @@ def eval( |
         )
         assertDataFrameEqual(sql_result_df, expected_df)
 
+    def test_arrow_udtf_lateral_join_disallowed(self):
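+        """A lateral join passing a correlated column to an Arrow UDTF must be rejected."""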
+        @arrow_udtf(returnType="x int, result int")
+        class SimpleArrowUDTF:
+            def eval(self, input_val: "pa.Array") -> Iterator["pa.Table"]:
+                val = input_val[0].as_py()
+                result_table = pa.table(
+                    {
+                        "x": pa.array([val], type=pa.int32()),
+                        "result": pa.array([val * 2], type=pa.int32()),
+                    }
+                )
+                yield result_table
+
+        self.spark.udtf.register("simple_arrow_udtf", SimpleArrowUDTF)
+
+        test_df = self.spark.createDataFrame([(1,), (2,), (3,)], "id int")
+        test_df.createOrReplaceTempView("test_table")
+
+        with self.assertRaisesRegex(Exception, "LATERAL_JOIN_WITH_ARROW_UDTF_UNSUPPORTED"):
+            self.spark.sql(
+                """
+                SELECT t.id, f.x, f.result
+                FROM test_table t, LATERAL simple_arrow_udtf(t.id) f
+                """
+            )
+
+    def test_arrow_udtf_lateral_join_with_table_argument_disallowed(self):
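+        """A lateral join on an Arrow UDTF that takes a table argument is rejected as well."""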
+        @arrow_udtf(returnType="filtered_id bigint")
+        class MixedArgsUDTF:
+            def eval(self, input_table: "pa.Table") -> Iterator["pa.Table"]:
+                filtered_data = input_table.filter(pc.greater(input_table["id"], 5))
+                result_table = pa.table({"filtered_id": filtered_data["id"]})
+                yield result_table
+
+        self.spark.udtf.register("mixed_args_udtf", MixedArgsUDTF)
+
+        test_df1 = self.spark.createDataFrame([(1,), (2,), (3,)], "id int")
+        test_df1.createOrReplaceTempView("test_table1")
+
+        test_df2 = self.spark.createDataFrame([(6,), (7,), (8,)], "id bigint")
+        test_df2.createOrReplaceTempView("test_table2")
+
+        # A table argument introduces a nested lateral join, so our CheckAnalysis rule
+        # doesn't trigger: the Arrow UDTF sits in the inner lateral join rather than the
+        # outer one the rule inspects. Spark's general lateral-join validation therefore
+        # fails first with NON_DETERMINISTIC_LATERAL_SUBQUERIES.
+        with self.assertRaisesRegex(
+            Exception,
+            "UNSUPPORTED_SUBQUERY_EXPRESSION_CATEGORY.NON_DETERMINISTIC_LATERAL_SUBQUERIES",
+        ):
+            self.spark.sql(
+                """
+                SELECT t1.id, f.filtered_id
+                FROM test_table1 t1, LATERAL mixed_args_udtf(table(SELECT * FROM test_table2)) f
+                """
+            )
+
+    def test_arrow_udtf_with_table_argument_then_lateral_join_allowed(self):
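+        """An Arrow UDTF with a table argument may appear in a regular (non-lateral) join."""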
+        @arrow_udtf(returnType="processed_id bigint")
+        class TableArgUDTF:
+            def eval(self, input_table: "pa.Table") -> Iterator["pa.Table"]:
+                processed_data = pc.add(input_table["id"], 100)
+                result_table = pa.table({"processed_id": processed_data})
+                yield result_table
+
+        self.spark.udtf.register("table_arg_udtf", TableArgUDTF)
+
+        source_df = self.spark.createDataFrame([(1,), (2,), (3,)], "id bigint")
+        source_df.createOrReplaceTempView("source_table")
+
+        join_df = self.spark.createDataFrame([("A",), ("B",), ("C",)], "label string")
+        join_df.createOrReplaceTempView("join_table")
+
+        result_df = self.spark.sql(
+            """
+            SELECT f.processed_id, j.label
+            FROM table_arg_udtf(table(SELECT * FROM source_table)) f,
+                 join_table j
+            ORDER BY f.processed_id, j.label
+            """
+        )
+
+        expected_data = [
+            (101, "A"),
+            (101, "B"),
+            (101, "C"),
+            (102, "A"),
+            (102, "B"),
+            (102, "C"),
+            (103, "A"),
+            (103, "B"),
+            (103, "C"),
+        ]
+        expected_df = self.spark.createDataFrame(expected_data, "processed_id bigint, label string")
+        assertDataFrameEqual(result_df, expected_df)
+
+    def test_arrow_udtf_table_argument_with_regular_udtf_lateral_join_allowed(self):
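+        """Arrow UDTF output can drive a lateral join into a regular Python UDTF."""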
+        @arrow_udtf(returnType="computed_value int")
+        class ComputeUDTF:
+            def eval(self, input_table: "pa.Table") -> Iterator["pa.Table"]:
+                total = pc.sum(input_table["value"]).as_py()
+                result_table = pa.table({"computed_value": pa.array([total], type=pa.int32())})
+                yield result_table
+
+        from pyspark.sql.functions import udtf
+        from pyspark.sql.types import StructType, StructField, IntegerType
+
+        @udtf(returnType=StructType([StructField("multiplied", IntegerType())]))
+        class MultiplyUDTF:
+            def eval(self, input_val: int):
+                yield (input_val * 3,)
+
+        self.spark.udtf.register("compute_udtf", ComputeUDTF)
+        self.spark.udtf.register("multiply_udtf", MultiplyUDTF)
+
+        values_df = self.spark.createDataFrame([(10,), (20,), (30,)], "value int")
+        values_df.createOrReplaceTempView("values_table")
+
+        result_df = self.spark.sql(
+            """
+            SELECT c.computed_value, m.multiplied
+            FROM compute_udtf(table(SELECT * FROM values_table) WITH SINGLE PARTITION) c,
+                 LATERAL multiply_udtf(c.computed_value) m
+            """
+        )
+
+        expected_df = self.spark.createDataFrame([(60, 180)], "computed_value int, multiplied int")
+        assertDataFrameEqual(result_df, expected_df)
+
 
 class ArrowUDTFTests(ArrowUDTFTestsMixin, ReusedSQLTestCase):
     pass