@@ -68,7 +68,8 @@ def __init__(self, properties: QueryProperties, **parameters):
6868 self .column_map , self .aggregate_functions = build_aggregations (self .aggregates )
6969
7070 self .buffer = []
71- self .max_buffer_size = 50 # Process in chunks to avoid excessive memory usage
71+ self .max_buffer_size = 100 # Process in chunks to avoid excessive memory usage
72+ self ._partial_aggregated = False # Track if we've done a partial aggregation
7273
7374 @property
7475 def config (self ): # pragma: no cover
@@ -86,38 +87,122 @@ def execute(self, morsel: pyarrow.Table, **kwargs):
8687 yield EOS
8788 return
8889
89- # If we have partial results in buffer, do final aggregation
90- if len (self .buffer ) > 0 :
91- table = pyarrow .concat_tables (
92- self .buffer ,
93- promote_options = "permissive" ,
94- )
90+ # Do final aggregation if we have buffered data
91+ table = pyarrow .concat_tables (
92+ self .buffer ,
93+ promote_options = "permissive" ,
94+ )
95+ # Only combine chunks if we haven't done partial aggregation yet
96+ # combine_chunks can fail after partial aggregation due to buffer structure
97+ if not self ._partial_aggregated :
9598 table = table .combine_chunks ()
99+
100+ # If we've done partial aggregations, the aggregate functions need adjusting
101+ # because columns like "*" have been renamed to "*_count"
102+ if self ._partial_aggregated :
103+ # Build new aggregate functions for re-aggregating partial results
104+ adjusted_aggs = []
105+ adjusted_column_map = {}
106+
107+ for field_name , function , _count_options in self .aggregate_functions :
108+ # For COUNT aggregates, the column is now named "*_count" and we need to SUM it
109+ if function == "count" :
110+ renamed_field = f"{ field_name } _count"
111+ adjusted_aggs .append ((renamed_field , "sum" , None ))
112+ # The final column will be named "*_count_sum", need to track for renaming
113+ for orig_name , mapped_name in self .column_map .items ():
114+ if mapped_name == f"{ field_name } _count" :
115+ adjusted_column_map [orig_name ] = f"{ renamed_field } _sum"
116+ # For other aggregates, we can re-aggregate with the same function
117+ else :
118+ renamed_field = f"{ field_name } _{ function } " .replace ("_hash_" , "_" )
119+ # Some aggregates can be re-aggregated (sum, max, min)
120+ if function in ("sum" , "max" , "min" , "hash_one" , "all" , "any" ):
121+ adjusted_aggs .append ((renamed_field , function , None ))
122+ # Track the mapping: original -> intermediate -> final
123+ for orig_name , mapped_name in self .column_map .items ():
124+ if mapped_name == renamed_field :
125+ # sum->sum, max->max, etc. means same name
126+ adjusted_column_map [orig_name ] = (
127+ f"{ renamed_field } _{ function } " .replace ("_hash_" , "_" )
128+ )
129+ elif function == "mean" :
130+ # For mean, just take one of the existing values (not ideal)
131+ adjusted_aggs .append ((renamed_field , "hash_one" , None ))
132+ for orig_name , mapped_name in self .column_map .items ():
133+ if mapped_name == renamed_field :
134+ adjusted_column_map [orig_name ] = f"{ renamed_field } _one"
135+ elif function == "hash_list" :
136+ # For ARRAY_AGG, we need to flatten lists
137+ adjusted_aggs .append ((renamed_field , "hash_list" , None ))
138+ for orig_name , mapped_name in self .column_map .items ():
139+ if mapped_name == renamed_field :
140+ adjusted_column_map [orig_name ] = f"{ renamed_field } _list"
141+ else :
142+ # For other aggregates, take one value
143+ adjusted_aggs .append ((renamed_field , "hash_one" , None ))
144+ for orig_name , mapped_name in self .column_map .items ():
145+ if mapped_name == renamed_field :
146+ adjusted_column_map [orig_name ] = f"{ renamed_field } _one"
147+
148+ groups = table .group_by (self .group_by_columns )
149+ groups = groups .aggregate (adjusted_aggs )
150+
151+ # Use the adjusted column map for selecting/renaming
152+ groups = groups .select (list (adjusted_column_map .values ()) + self .group_by_columns )
153+ groups = groups .rename_columns (
154+ list (adjusted_column_map .keys ()) + self .group_by_columns
155+ )
156+ else :
96157 groups = table .group_by (self .group_by_columns )
97158 groups = groups .aggregate (self .aggregate_functions )
98- self .buffer = [groups ] # Replace buffer with final result
99-
100- # Now buffer has the final aggregated result
101- groups = self .buffer [0 ]
102-
103- # do the secondary activities for ARRAY_AGG
104- for node in get_all_nodes_of_type (self .aggregates , select_nodes = (NodeType .AGGREGATOR ,)):
105- if node .value == "ARRAY_AGG" and node .order or node .limit :
106- # rip the column out of the table
107- column_name = self .column_map [node .schema_column .identity ]
108- column_def = groups .field (column_name ) # this is used
109- column = groups .column (column_name ).to_pylist ()
110- groups = groups .drop ([column_name ])
159+
160+ # project to the desired column names from the pyarrow names
161+ groups = groups .select (list (self .column_map .values ()) + self .group_by_columns )
162+ groups = groups .rename_columns (list (self .column_map .keys ()) + self .group_by_columns )
163+
164+ # do the secondary activities for ARRAY_AGG (order and limit)
165+ array_agg_nodes = [
166+ node
167+ for node in get_all_nodes_of_type (
168+ self .aggregates , select_nodes = (NodeType .AGGREGATOR ,)
169+ )
170+ if node .value == "ARRAY_AGG" and (node .order or node .limit )
171+ ]
172+
173+ if array_agg_nodes :
174+ # Process all ARRAY_AGG columns that need ordering/limiting
175+ arrays_to_update = {}
176+ field_defs = {}
177+
178+ for node in array_agg_nodes :
179+ column_name = node .schema_column .identity
180+
181+ # Store field definition before we drop the column
182+ field_defs [column_name ] = groups .field (column_name )
183+
184+ # Extract and process the data
185+ column_data = groups .column (column_name ).to_pylist ()
186+
187+ # Apply ordering if specified
111188 if node .order :
112- column = [sorted (c , reverse = bool (node .order [0 ][1 ])) for c in column ]
189+ column_data = [
190+ sorted (c , reverse = bool (node .order [0 ][1 ])) for c in column_data
191+ ]
192+
193+ # Apply limit if specified
113194 if node .limit :
114- column = [c [: node .limit ] for c in column ]
115- # put the new column into the table
116- groups = groups .append_column (column_def , [column ])
195+ column_data = [c [: node .limit ] for c in column_data ]
196+
197+ arrays_to_update [column_name ] = column_data
198+
199+ # Drop all columns we're updating
200+ columns_to_drop = list (arrays_to_update .keys ())
201+ groups = groups .drop (columns_to_drop )
117202
118- # project to the desired column names from the pyarrow names
119- groups = groups . select ( list ( self . column_map . values ()) + self . group_by_columns )
120- groups = groups .rename_columns ( list ( self . column_map . keys ()) + self . group_by_columns )
203+ # Append all updated columns back
204+ for column_name , column_data in arrays_to_update . items ():
205+ groups = groups .append_column ( field_defs [ column_name ], [ column_data ] )
121206
122207 num_rows = groups .num_rows
123208 for start in range (0 , num_rows , CHUNK_SIZE ):
@@ -128,9 +213,10 @@ def execute(self, morsel: pyarrow.Table, **kwargs):
128213
129214 morsel = project (morsel , self .all_identifiers )
130215 # Add a "*" column, this is an int because when a bool it miscounts
216+ # FIX: Use int8 as the comment states (bool can miscount)
131217 if "*" not in morsel .column_names :
132218 morsel = morsel .append_column (
133- "*" , [numpy .ones (shape = morsel .num_rows , dtype = numpy .bool_ )]
219+ "*" , [numpy .ones (shape = morsel .num_rows , dtype = numpy .int8 )]
134220 )
135221 if self .evaluatable_nodes :
136222 morsel = evaluate_and_append (self .evaluatable_nodes , morsel )
@@ -144,9 +230,11 @@ def execute(self, morsel: pyarrow.Table, **kwargs):
144230 self .buffer ,
145231 promote_options = "permissive" ,
146232 )
233+ # Only combine chunks once before aggregation
147234 table = table .combine_chunks ()
148235 groups = table .group_by (self .group_by_columns )
149236 groups = groups .aggregate (self .aggregate_functions )
150237 self .buffer = [groups ] # Replace buffer with partial result
238+ self ._partial_aggregated = True # Mark that we've done a partial aggregation
151239
152240 yield None
0 commit comments