1- from typing import Any , Dict , Literal
1+ from typing import Any , Dict , Literal , Optional
22
33from pyspark .sql import DataFrame
44from pyspark .sql .functions import col , collect_set
5- from pyspark .sql .functions import max as spark_max
65
76
87def ingest_spark_dataframe (
98 spark_dataframe : DataFrame ,
109 save_mode : Literal ["Overwrite" , "Append" ],
1110 options : Dict [str , Any ],
11+ num_groups : Optional [int ] = None
1212) -> None :
1313 """
1414 Saves a Spark DataFrame in multiple batches based on the 'batch' column values.
@@ -26,6 +26,10 @@ def ingest_spark_dataframe(
2626 options : Dict[str, Any]
2727 Dictionary of options to configure the DataFrame writer.
2828 Refer to example for more information.
29+ num_groups: Optional[int], optional
30+ The number of partitions to split each batch of the Spark DataFrame into.
31+ If not provided, it is calculated as the number of distinct 'group' values in the DataFrame.
32+ It is more efficient to pass this parameter explicitly. By default None
2933
3034 Example
3135 -------
@@ -67,9 +71,10 @@ def ingest_spark_dataframe(
6771 for batch_value in batch_list
6872 ]
6973
74+ num_groups = num_groups or spark_dataframe .select ("group" ).distinct ().count ()
75+
7076 # write batches serially to Neo4j database
7177 for batch in batches :
72- num_groups = batch .select ("group" ).distinct ().count ()
7378 (
7479 batch .repartition (num_groups , "group" ) # define parallel groups for ingest
7580 .write .mode (save_mode )
0 commit comments