Skip to content

Commit 2e8a62b

Browse files
committed
finish Stream impl for OpenAIEventMappedStream
1 parent b41689e commit 2e8a62b

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

async-openai-wasm/src/client.rs

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,7 @@ pub struct OpenAIEventMappedStream<O>
549549
#[pin]
550550
stream: Filter<EventSource, future::Ready<bool>, fn(&Result<Event, reqwest_eventsource::Error>) -> future::Ready<bool>>,
551551
event_mapper: Box<dyn Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static>,
552+
done: bool,
552553
_phantom_data: PhantomData<O>,
553554
}
554555

@@ -562,6 +563,7 @@ impl<O> OpenAIEventMappedStream<O>
562563
// filter out the first event which is always Event::Open
563564
future::ready(!(result.is_ok() && result.as_ref().unwrap().eq(&Event::Open)))
564565
),
566+
done: false,
565567
event_mapper: Box::new(event_mapper),
566568
_phantom_data: PhantomData,
567569
}
@@ -574,8 +576,12 @@ impl<O> Stream for OpenAIEventMappedStream<O>
574576
{
575577
type Item = Result<O, OpenAIError>;
576578

579+
// TODO: test this
577580
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
578581
let this = self.project();
582+
if *this.done {
583+
return Poll::Ready(None);
584+
}
579585
let stream: Pin<&mut _> = this.stream;
580586
match stream.poll_next(cx) {
581587
Poll::Ready(response) => {
@@ -586,13 +592,19 @@ impl<O> Stream for OpenAIEventMappedStream<O>
586592
Event::Open => unreachable!(), // it has been filtered out
587593
Event::Message(message) => {
588594
if message.data == "[DONE]" {
589-
Poll::Ready(None) // end of the stream, defined by OpenAI
590-
} else {
591-
todo!()
595+
*this.done = true;
596+
}
597+
let response = (this.event_mapper)(message);
598+
match response {
599+
Ok(output) => Poll::Ready(Some(Ok(output))),
600+
Err(_) => Poll::Ready(None)
592601
}
593602
}
594603
}
595-
Err(e) => Poll::Ready(Some(Err(OpenAIError::StreamError(e.to_string()))))
604+
Err(e) => {
605+
*this.done = true;
606+
Poll::Ready(Some(Err(OpenAIError::StreamError(e.to_string()))))
607+
}
596608
}
597609
}
598610
}

0 commit comments

Comments
 (0)