@@ -446,10 +446,10 @@ impl<C: Config> Client<C> {
446
446
path : & str ,
447
447
request : I ,
448
448
event_mapper : impl Fn ( eventsource_stream:: Event ) -> Result < O , OpenAIError > + Send + ' static ,
449
- ) -> Pin < Box < dyn Stream < Item = Result < O , OpenAIError > > + Send > >
449
+ ) -> OpenAIEventMappedStream < O >
450
450
where
451
451
I : Serialize ,
452
- O : DeserializeOwned + Send + ' static ,
452
+ O : DeserializeOwned + Send + ' static
453
453
{
454
454
let event_source = self
455
455
. http_client
@@ -460,8 +460,7 @@ impl<C: Config> Client<C> {
460
460
. eventsource ( )
461
461
. unwrap ( ) ;
462
462
463
- // stream_mapped_raw_events(event_source, event_mapper).await
464
- todo ! ( )
463
+ OpenAIEventMappedStream :: new ( event_source, event_mapper)
465
464
}
466
465
467
466
/// Make HTTP GET request to receive SSE
@@ -491,13 +490,13 @@ impl<C: Config> Client<C> {
491
490
/// Request which responds with SSE.
492
491
/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
493
492
#[ pin_project]
494
- pub struct OpenAIEventStream < O > {
493
+ pub struct OpenAIEventStream < O : DeserializeOwned + Send + ' static > {
495
494
#[ pin]
496
495
stream : Filter < EventSource , future:: Ready < bool > , fn ( & Result < Event , reqwest_eventsource:: Error > ) -> future:: Ready < bool > > ,
497
496
_phantom_data : PhantomData < O > ,
498
497
}
499
498
500
- impl < O > OpenAIEventStream < O > {
499
+ impl < O : DeserializeOwned + Send + ' static > OpenAIEventStream < O > {
501
500
pub ( crate ) fn new ( event_source : EventSource ) -> Self {
502
501
Self {
503
502
stream : event_source. filter ( |result|
@@ -543,6 +542,66 @@ impl<O: DeserializeOwned + Send + 'static> Stream for OpenAIEventStream<O> {
543
542
}
544
543
}
545
544
545
+ #[ pin_project]
546
+ pub struct OpenAIEventMappedStream < O >
547
+ where O : Send + ' static
548
+ {
549
+ #[ pin]
550
+ stream : Filter < EventSource , future:: Ready < bool > , fn ( & Result < Event , reqwest_eventsource:: Error > ) -> future:: Ready < bool > > ,
551
+ event_mapper : Box < dyn Fn ( eventsource_stream:: Event ) -> Result < O , OpenAIError > + Send + ' static > ,
552
+ _phantom_data : PhantomData < O > ,
553
+ }
554
+
555
+ impl < O > OpenAIEventMappedStream < O >
556
+ where O : Send + ' static
557
+ {
558
+ pub ( crate ) fn new < M > ( event_source : EventSource , event_mapper : M ) -> Self
559
+ where M : Fn ( eventsource_stream:: Event ) -> Result < O , OpenAIError > + Send + ' static {
560
+ Self {
561
+ stream : event_source. filter ( |result|
562
+ // filter out the first event which is always Event::Open
563
+ future:: ready ( !( result. is_ok ( ) && result. as_ref ( ) . unwrap ( ) . eq ( & Event :: Open ) ) )
564
+ ) ,
565
+ event_mapper : Box :: new ( event_mapper) ,
566
+ _phantom_data : PhantomData ,
567
+ }
568
+ }
569
+ }
570
+
571
+
572
+ impl < O > Stream for OpenAIEventMappedStream < O >
573
+ where O : Send + ' static
574
+ {
575
+ type Item = Result < O , OpenAIError > ;
576
+
577
+ fn poll_next ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Option < Self :: Item > > {
578
+ let this = self . project ( ) ;
579
+ let stream: Pin < & mut _ > = this. stream ;
580
+ match stream. poll_next ( cx) {
581
+ Poll :: Ready ( response) => {
582
+ match response {
583
+ None => Poll :: Ready ( None ) , // end of the stream
584
+ Some ( result) => match result {
585
+ Ok ( event) => match event {
586
+ Event :: Open => unreachable ! ( ) , // it has been filtered out
587
+ Event :: Message ( message) => {
588
+ if message. data == "[DONE]" {
589
+ Poll :: Ready ( None ) // end of the stream, defined by OpenAI
590
+ } else {
591
+ todo ! ( )
592
+ }
593
+ }
594
+ }
595
+ Err ( e) => Poll :: Ready ( Some ( Err ( OpenAIError :: StreamError ( e. to_string ( ) ) ) ) )
596
+ }
597
+ }
598
+ }
599
+ Poll :: Pending => Poll :: Pending
600
+ }
601
+ }
602
+ }
603
+
604
+
546
605
// pub(crate) async fn stream_mapped_raw_events<O>(
547
606
// mut event_source: EventSource,
548
607
// event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
0 commit comments