@@ -1236,81 +1236,59 @@ impl Stream for RepartitionStream {
12361236 let poll = loop {
12371237 match & mut this. state {
12381238 RepartitionStreamState :: ReceivingFromChannel => {
1239- match this. input . recv ( ) . poll_unpin ( cx) {
1240- Poll :: Ready ( value) => match value {
1241- Some ( Some ( v) ) => match v {
1239+ let recv = match this. input . recv ( ) . poll_unpin ( cx) {
1240+ Poll :: Pending => break Poll :: Pending ,
1241+ Poll :: Ready ( value) => value,
1242+ } ;
1243+
1244+ match recv {
1245+ Some ( Some ( result) ) => {
1246+ match result {
12421247 Ok ( RepartitionBatch :: Memory ( batch) ) => {
1243- let _timer =
1244- this. baseline_metrics . elapsed_compute ( ) . timer ( ) ;
1245- // Release memory and return
1246- this. reservation
1247- . lock ( )
1248- . shrink ( batch. get_array_memory_size ( ) ) ;
1249- break Poll :: Ready ( Some ( Ok ( batch) ) ) ;
1248+ let bytes = batch. get_array_memory_size ( ) ;
1249+ break this. ready_with_batch ( batch, Some ( bytes) ) ;
12501250 }
12511251 Ok ( RepartitionBatch :: Spilled { spill_file, size } ) => {
1252- // Read from disk - SpillReaderStream uses tokio::fs internally
1253- // Pass the original size for validation
12541252 match this
12551253 . spill_manager
12561254 . read_spill_as_stream ( spill_file, Some ( size) )
12571255 {
12581256 Ok ( stream) => {
12591257 this. state = RepartitionStreamState :: ReadingSpilledBatch ( stream) ;
1260- // Continue loop to poll the stream immediately
12611258 continue ;
12621259 }
1263- Err ( e) => {
1264- let _timer = this
1265- . baseline_metrics
1266- . elapsed_compute ( )
1267- . timer ( ) ;
1268- break Poll :: Ready ( Some ( Err ( e) ) ) ;
1269- }
1260+ Err ( err) => break this. ready_with_error ( err) ,
12701261 }
12711262 }
1272- Err ( e) => {
1273- let _timer =
1274- this. baseline_metrics . elapsed_compute ( ) . timer ( ) ;
1275- break Poll :: Ready ( Some ( Err ( e) ) ) ;
1276- }
1277- } ,
1278- Some ( None ) => {
1279- this. num_input_partitions_processed += 1 ;
1280-
1281- if this. num_input_partitions
1282- == this. num_input_partitions_processed
1283- {
1284- // all input partitions have finished sending batches
1285- break Poll :: Ready ( None ) ;
1286- } else {
1287- // other partitions still have data to send
1288- continue ;
1289- }
1263+ Err ( err) => break this. ready_with_error ( err) ,
12901264 }
1291- None => break Poll :: Ready ( None ) ,
1292- } ,
1293- Poll :: Pending => break Poll :: Pending ,
1265+ }
1266+ Some ( None ) => {
1267+ this. num_input_partitions_processed += 1 ;
1268+ if this. num_input_partitions
1269+ == this. num_input_partitions_processed
1270+ {
1271+ break Poll :: Ready ( None ) ;
1272+ }
1273+ continue ;
1274+ }
1275+ None => break Poll :: Ready ( None ) ,
12941276 }
12951277 }
12961278 RepartitionStreamState :: ReadingSpilledBatch ( stream) => {
12971279 match stream. poll_next_unpin ( cx) {
1280+ Poll :: Pending => break Poll :: Pending ,
12981281 Poll :: Ready ( Some ( Ok ( batch) ) ) => {
1299- let _timer = this. baseline_metrics . elapsed_compute ( ) . timer ( ) ;
1300- // Return batch and stay in ReadingSpilledBatch state to read more batches
1301- break Poll :: Ready ( Some ( Ok ( batch) ) ) ;
1282+ break this. ready_with_batch ( batch, None )
13021283 }
1303- Poll :: Ready ( Some ( Err ( e) ) ) => {
1304- let _timer = this. baseline_metrics . elapsed_compute ( ) . timer ( ) ;
1284+ Poll :: Ready ( Some ( Err ( err) ) ) => {
13051285 this. state = RepartitionStreamState :: ReceivingFromChannel ;
1306- break Poll :: Ready ( Some ( Err ( e ) ) ) ;
1286+ break this . ready_with_error ( err ) ;
13071287 }
13081288 Poll :: Ready ( None ) => {
1309- // Spill stream ended - go back to receiving from channel
13101289 this. state = RepartitionStreamState :: ReceivingFromChannel ;
13111290 continue ;
13121291 }
1313- Poll :: Pending => break Poll :: Pending ,
13141292 }
13151293 }
13161294 }
@@ -1320,6 +1298,28 @@ impl Stream for RepartitionStream {
13201298 }
13211299}
13221300
1301+ impl RepartitionStream {
1302+ fn ready_with_batch (
1303+ & mut self ,
1304+ batch : RecordBatch ,
1305+ released_bytes : Option < usize > ,
1306+ ) -> Poll < Option < Result < RecordBatch > > > {
1307+ if let Some ( bytes) = released_bytes {
1308+ self . reservation . lock ( ) . shrink ( bytes) ;
1309+ }
1310+ let _timer = self . baseline_metrics . elapsed_compute ( ) . timer ( ) ;
1311+ Poll :: Ready ( Some ( Ok ( batch) ) )
1312+ }
1313+
1314+ fn ready_with_error (
1315+ & mut self ,
1316+ err : DataFusionError ,
1317+ ) -> Poll < Option < Result < RecordBatch > > > {
1318+ let _timer = self . baseline_metrics . elapsed_compute ( ) . timer ( ) ;
1319+ Poll :: Ready ( Some ( Err ( err) ) )
1320+ }
1321+ }
1322+
13231323impl RecordBatchStream for RepartitionStream {
13241324 /// Get the schema
13251325 fn schema ( & self ) -> SchemaRef {
0 commit comments