Skip to content

Commit eeedb94

Browse files
committed
add extra guard on OpenAIEventStream::poll_next
1 parent 2e8a62b commit eeedb94

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

async-openai-wasm/src/client.rs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,7 @@ impl<C: Config> Client<C> {
493493
pub struct OpenAIEventStream<O: DeserializeOwned + Send + 'static> {
494494
#[pin]
495495
stream: Filter<EventSource, future::Ready<bool>, fn(&Result<Event, reqwest_eventsource::Error>) -> future::Ready<bool>>,
496+
done: bool,
496497
_phantom_data: PhantomData<O>,
497498
}
498499

@@ -503,6 +504,7 @@ impl<O: DeserializeOwned + Send + 'static> OpenAIEventStream<O> {
503504
// filter out the first event which is always Event::Open
504505
future::ready(!(result.is_ok() && result.as_ref().unwrap().eq(&Event::Open)))
505506
),
507+
done: false,
506508
_phantom_data: PhantomData,
507509
}
508510
}
@@ -513,6 +515,9 @@ impl<O: DeserializeOwned + Send + 'static> Stream for OpenAIEventStream<O> {
513515

514516
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
515517
let this = self.project();
518+
if *this.done {
519+
return Poll::Ready(None);
520+
}
516521
let stream: Pin<&mut _> = this.stream;
517522
match stream.poll_next(cx) {
518523
Poll::Ready(response) => {
@@ -523,17 +528,24 @@ impl<O: DeserializeOwned + Send + 'static> Stream for OpenAIEventStream<O> {
523528
Event::Open => unreachable!(), // it has been filtered out
524529
Event::Message(message) => {
525530
if message.data == "[DONE]" {
531+
*this.done = true;
526532
Poll::Ready(None) // end of the stream, defined by OpenAI
527533
} else {
528534
// deserialize the data
529535
match serde_json::from_str::<O>(&message.data) {
530-
Err(e) => Poll::Ready(Some(Err(map_deserialization_error(e, &message.data.as_bytes())))),
536+
Err(e) => {
537+
*this.done = true;
538+
Poll::Ready(Some(Err(map_deserialization_error(e, &message.data.as_bytes()))))
539+
}
531540
Ok(output) => Poll::Ready(Some(Ok(output))),
532541
}
533542
}
534543
}
535544
}
536-
Err(e) => Poll::Ready(Some(Err(OpenAIError::StreamError(e.to_string()))))
545+
Err(e) => {
546+
*this.done = true;
547+
Poll::Ready(Some(Err(OpenAIError::StreamError(e.to_string()))))
548+
}
537549
}
538550
}
539551
}

0 commit comments

Comments
 (0)