Skip to content

Commit b41689e

Browse files
committed
update
1 parent 942f24b commit b41689e

File tree

2 files changed

+67
-12
lines changed

2 files changed

+67
-12
lines changed

async-openai-wasm/src/client.rs

Lines changed: 65 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -446,10 +446,10 @@ impl<C: Config> Client<C> {
446446
path: &str,
447447
request: I,
448448
event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,
449-
) -> Pin<Box<dyn Stream<Item=Result<O, OpenAIError>> + Send>>
449+
) -> OpenAIEventMappedStream<O>
450450
where
451451
I: Serialize,
452-
O: DeserializeOwned + Send + 'static,
452+
O: DeserializeOwned + Send + 'static
453453
{
454454
let event_source = self
455455
.http_client
@@ -460,8 +460,7 @@ impl<C: Config> Client<C> {
460460
.eventsource()
461461
.unwrap();
462462

463-
// stream_mapped_raw_events(event_source, event_mapper).await
464-
todo!()
463+
OpenAIEventMappedStream::new(event_source, event_mapper)
465464
}
466465

467466
/// Make HTTP GET request to receive SSE
@@ -491,13 +490,13 @@ impl<C: Config> Client<C> {
491490
/// Request which responds with SSE.
492491
/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
493492
#[pin_project]
494-
pub struct OpenAIEventStream<O> {
493+
pub struct OpenAIEventStream<O: DeserializeOwned + Send + 'static> {
495494
#[pin]
496495
stream: Filter<EventSource, future::Ready<bool>, fn(&Result<Event, reqwest_eventsource::Error>) -> future::Ready<bool>>,
497496
_phantom_data: PhantomData<O>,
498497
}
499498

500-
impl<O> OpenAIEventStream<O> {
499+
impl<O: DeserializeOwned + Send + 'static> OpenAIEventStream<O> {
501500
pub(crate) fn new(event_source: EventSource) -> Self {
502501
Self {
503502
stream: event_source.filter(|result|
@@ -543,6 +542,66 @@ impl<O: DeserializeOwned + Send + 'static> Stream for OpenAIEventStream<O> {
543542
}
544543
}
545544

545+
#[pin_project]
546+
pub struct OpenAIEventMappedStream<O>
547+
where O: Send + 'static
548+
{
549+
#[pin]
550+
stream: Filter<EventSource, future::Ready<bool>, fn(&Result<Event, reqwest_eventsource::Error>) -> future::Ready<bool>>,
551+
event_mapper: Box<dyn Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static>,
552+
_phantom_data: PhantomData<O>,
553+
}
554+
555+
impl<O> OpenAIEventMappedStream<O>
556+
where O: Send + 'static
557+
{
558+
pub(crate) fn new<M>(event_source: EventSource, event_mapper: M) -> Self
559+
where M: Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static {
560+
Self {
561+
stream: event_source.filter(|result|
562+
// filter out the first event which is always Event::Open
563+
future::ready(!(result.is_ok() && result.as_ref().unwrap().eq(&Event::Open)))
564+
),
565+
event_mapper: Box::new(event_mapper),
566+
_phantom_data: PhantomData,
567+
}
568+
}
569+
}
570+
571+
572+
impl<O> Stream for OpenAIEventMappedStream<O>
573+
where O: Send + 'static
574+
{
575+
type Item = Result<O, OpenAIError>;
576+
577+
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
578+
let this = self.project();
579+
let stream: Pin<&mut _> = this.stream;
580+
match stream.poll_next(cx) {
581+
Poll::Ready(response) => {
582+
match response {
583+
None => Poll::Ready(None), // end of the stream
584+
Some(result) => match result {
585+
Ok(event) => match event {
586+
Event::Open => unreachable!(), // it has been filtered out
587+
Event::Message(message) => {
588+
if message.data == "[DONE]" {
589+
Poll::Ready(None) // end of the stream, defined by OpenAI
590+
} else {
591+
todo!()
592+
}
593+
}
594+
}
595+
Err(e) => Poll::Ready(Some(Err(OpenAIError::StreamError(e.to_string()))))
596+
}
597+
}
598+
}
599+
Poll::Pending => Poll::Pending
600+
}
601+
}
602+
}
603+
604+
546605
// pub(crate) async fn stream_mapped_raw_events<O>(
547606
// mut event_source: EventSource,
548607
// event_mapper: impl Fn(eventsource_stream::Event) -> Result<O, OpenAIError> + Send + 'static,

async-openai-wasm/src/types/assistant_stream.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
use std::pin::Pin;
2-
3-
use futures::Stream;
41
use serde::Deserialize;
52

3+
use crate::client::OpenAIEventMappedStream;
64
use crate::error::{ApiError, map_deserialization_error, OpenAIError};
75

86
use super::{
@@ -28,7 +26,6 @@ use super::{
2826
/// We may add additional events over time, so we recommend handling unknown events gracefully
2927
/// in your code. See the [Assistants API quickstart](https://platform.openai.com/docs/assistants/overview) to learn how to
3028
/// integrate the Assistants API with streaming.
31-
3229
#[derive(Debug, Deserialize, Clone)]
3330
#[serde(tag = "event", content = "data")]
3431
#[non_exhaustive]
@@ -110,8 +107,7 @@ pub enum AssistantStreamEvent {
110107
Done(String),
111108
}
112109

113-
pub type AssistantEventStream =
114-
Pin<Box<dyn Stream<Item = Result<AssistantStreamEvent, OpenAIError>> + Send>>;
110+
pub type AssistantEventStream = OpenAIEventMappedStream<AssistantStreamEvent>;
115111

116112
impl TryFrom<eventsource_stream::Event> for AssistantStreamEvent {
117113
type Error = OpenAIError;

0 commit comments

Comments
 (0)