Skip to content

Shubhra/ajs 37 refactor tts with streams #402

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: dev-1.0
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions agents/src/tts/stream_adapter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ export class StreamAdapterWrapper extends SynthesizeStream {

async #run() {
const forwardInput = async () => {
for await (const input of this.input) {
while (true) {
const { done, value: input } = await this.inputReader.read();
if (done) break;
if (input === SynthesizeStream.FLUSH_SENTINEL) {
this.#sentenceStream.flush();
} else {
Expand All @@ -65,10 +67,10 @@ export class StreamAdapterWrapper extends SynthesizeStream {
const synthesize = async () => {
for await (const ev of this.#sentenceStream) {
for await (const audio of this.#tts.synthesize(ev.token)) {
this.output.put(audio);
this.outputWriter.write(audio);
}
}
this.output.put(SynthesizeStream.END_OF_STREAM);
this.outputWriter.write(SynthesizeStream.END_OF_STREAM);
};

Promise.all([forwardInput(), synthesize()]);
Expand Down
148 changes: 115 additions & 33 deletions agents/src/tts/tts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,12 @@
import type { AudioFrame } from '@livekit/rtc-node';
import type { TypedEventEmitter as TypedEmitter } from '@livekit/typed-emitter';
import { EventEmitter } from 'node:events';
import type { ReadableStream } from 'node:stream/web';
import { log } from '../log.js';
import type { TTSMetrics } from '../metrics/base.js';
import { AsyncIterableQueue, mergeFrames } from '../utils.js';
import { DeferredReadableStream } from '../stream/deferred_stream.js';
import { IdentityTransform } from '../stream/identity_transform.js';
import { mergeFrames } from '../utils.js';

/** SynthesizedAudio is a packet of speech synthesis as returned by the TTS. */
export interface SynthesizedAudio {
Expand Down Expand Up @@ -105,22 +109,73 @@ export abstract class SynthesizeStream
{
protected static readonly FLUSH_SENTINEL = Symbol('FLUSH_SENTINEL');
static readonly END_OF_STREAM = Symbol('END_OF_STREAM');
protected input = new AsyncIterableQueue<string | typeof SynthesizeStream.FLUSH_SENTINEL>();
protected queue = new AsyncIterableQueue<
protected inputReader: ReadableStreamDefaultReader<
string | typeof SynthesizeStream.FLUSH_SENTINEL
>;
protected outputWriter: WritableStreamDefaultWriter<
SynthesizedAudio | typeof SynthesizeStream.END_OF_STREAM
>();
protected output = new AsyncIterableQueue<
SynthesizedAudio | typeof SynthesizeStream.END_OF_STREAM
>();
>;
protected closed = false;
abstract label: string;
#tts: TTS;
#metricsPendingTexts: string[] = [];
#metricsText = '';
#monitorMetricsTask?: Promise<void>;

private deferredInputStream: DeferredReadableStream<
string | typeof SynthesizeStream.FLUSH_SENTINEL
>;
private metricsStream: ReadableStream<SynthesizedAudio | typeof SynthesizeStream.END_OF_STREAM>;
private input = new IdentityTransform<string | typeof SynthesizeStream.FLUSH_SENTINEL>();
private output = new IdentityTransform<
SynthesizedAudio | typeof SynthesizeStream.END_OF_STREAM
>();
private inputWriter: WritableStreamDefaultWriter<string | typeof SynthesizeStream.FLUSH_SENTINEL>;
private outputReader: ReadableStreamDefaultReader<
SynthesizedAudio | typeof SynthesizeStream.END_OF_STREAM
>;
private logger = log();
private inputClosed = false;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is duplicative. The Readable/WritableStreamDefaultWriter internally tracks whether it's closed but doesn't expose it. The only way to know is when you try to write to or close an already-closed writer, it throws an error.

The other option would be to

    try {
      this.inputWriter.write(SynthesizeStream.FLUSH_SENTINEL);
    } catch (error) {
      throw new Error('Input is closed');
    }

everywhere, which doesn't seem much better. Let me know what you thoughts are @lukasIO @toubatbrian


constructor(tts: TTS) {
this.#tts = tts;
this.deferredInputStream = new DeferredReadableStream();

this.inputWriter = this.input.writable.getWriter();
this.inputReader = this.input.readable.getReader();
this.outputWriter = this.output.writable.getWriter();

const [outputStream, metricsStream] = this.output.readable.tee();
this.outputReader = outputStream.getReader();
this.metricsStream = metricsStream;

this.pumpDeferredStream();
this.monitorMetrics();
}

/**
* Reads from the deferred input stream and forwards chunks to the input writer.
*
* Note: we can't just do this.deferredInputStream.stream.pipeTo(this.input.writable)
* because the inputWriter locks the this.input.writable stream. All writes must go through
* the inputWriter.
*/
private async pumpDeferredStream() {
const reader = this.deferredInputStream.stream.getReader();
try {
while (true) {
const { done, value } = await reader.read();
if (done || value === SynthesizeStream.FLUSH_SENTINEL) {
break;
}
this.inputWriter.write(value);
}
} catch (error) {
this.logger.error(error, 'Error reading deferred input stream');
} finally {
reader.releaseLock();
this.flush();
this.endInput();
}
}

protected async monitorMetrics() {
Expand Down Expand Up @@ -148,9 +203,11 @@ export abstract class SynthesizeStream
}
};

for await (const audio of this.queue) {
this.output.put(audio);
if (audio === SynthesizeStream.END_OF_STREAM) continue;
const metricsReader = this.metricsStream.getReader();

while (true) {
const { done, value: audio } = await metricsReader.read();
if (done || audio === SynthesizeStream.END_OF_STREAM) break;
requestId = audio.requestId;
if (!ttfb) {
ttfb = process.hrtime.bigint() - startTime;
Expand All @@ -164,23 +221,24 @@ export abstract class SynthesizeStream
if (requestId) {
emit();
}
this.output.close();
}

updateInputStream(text: ReadableStream<string>) {
this.deferredInputStream.setSource(text);
}

/** Push a string of text to the TTS */
/** @deprecated Use `updateInputStream` instead */
pushText(text: string) {
if (!this.#monitorMetricsTask) {
this.#monitorMetricsTask = this.monitorMetrics();
}
this.#metricsText += text;

if (this.input.closed) {
if (this.inputClosed) {
throw new Error('Input is closed');
}
if (this.closed) {
throw new Error('Stream is closed');
}
this.input.put(text);
this.inputWriter.write(text);
}

/** Flush the TTS, causing it to process all pending text */
Expand All @@ -189,34 +247,41 @@ export abstract class SynthesizeStream
this.#metricsPendingTexts.push(this.#metricsText);
this.#metricsText = '';
}
if (this.input.closed) {
if (this.inputClosed) {
throw new Error('Input is closed');
}
if (this.closed) {
throw new Error('Stream is closed');
}
this.input.put(SynthesizeStream.FLUSH_SENTINEL);
this.inputWriter.write(SynthesizeStream.FLUSH_SENTINEL);
}

/** Mark the input as ended and forbid additional pushes */
endInput() {
if (this.input.closed) {
if (this.inputClosed) {
throw new Error('Input is closed');
}
if (this.closed) {
throw new Error('Stream is closed');
}
this.input.close();
this.inputClosed = true;
this.inputWriter.close();
}

next(): Promise<IteratorResult<SynthesizedAudio | typeof SynthesizeStream.END_OF_STREAM>> {
return this.output.next();
return this.outputReader.read().then(({ done, value }) => {
if (done) {
return { done: true, value: undefined };
}
return { done: false, value };
});
}

/** Close both the input and output of the TTS stream */
close() {
this.input.close();
this.output.close();
if (!this.inputClosed) {
this.inputWriter.close();
}
this.closed = true;
}

Expand All @@ -240,17 +305,26 @@ export abstract class SynthesizeStream
* exports its own child ChunkedStream class, which inherits this class's methods.
*/
export abstract class ChunkedStream implements AsyncIterableIterator<SynthesizedAudio> {
protected queue = new AsyncIterableQueue<SynthesizedAudio>();
protected output = new AsyncIterableQueue<SynthesizedAudio>();
protected outputWriter: WritableStreamDefaultWriter<
SynthesizedAudio | typeof SynthesizeStream.END_OF_STREAM
>;
protected closed = false;
abstract label: string;
#text: string;
#tts: TTS;
private output = new IdentityTransform<SynthesizedAudio>();
private outputReader: ReadableStreamDefaultReader<SynthesizedAudio>;
private metricsStream: ReadableStream<SynthesizedAudio>;

constructor(text: string, tts: TTS) {
this.#text = text;
this.#tts = tts;

this.outputWriter = this.output.writable.getWriter();
const [outputStream, metricsStream] = this.output.readable.tee();
this.outputReader = outputStream.getReader();
this.metricsStream = metricsStream;

this.monitorMetrics();
}

Expand All @@ -260,15 +334,18 @@ export abstract class ChunkedStream implements AsyncIterableIterator<Synthesized
let ttfb: bigint | undefined;
let requestId = '';

for await (const audio of this.queue) {
this.output.put(audio);
const metricsReader = this.metricsStream.getReader();

while (true) {
const { done, value: audio } = await metricsReader.read();
if (done) break;

requestId = audio.requestId;
if (!ttfb) {
ttfb = process.hrtime.bigint() - startTime;
}
audioDuration += audio.frame.samplesPerChannel / audio.frame.sampleRate;
}
this.output.close();

const duration = process.hrtime.bigint() - startTime;
const metrics: TTSMetrics = {
Expand All @@ -294,14 +371,19 @@ export abstract class ChunkedStream implements AsyncIterableIterator<Synthesized
return mergeFrames(frames);
}

next(): Promise<IteratorResult<SynthesizedAudio>> {
return this.output.next();
async next(): Promise<IteratorResult<SynthesizedAudio>> {
const { done, value } = await this.outputReader.read();
if (done) {
return { done: true, value: undefined };
}
return { done: false, value };
}

/** Close both the input and output of the TTS stream */
close() {
this.queue.close();
this.output.close();
if (!this.closed) {
this.outputWriter.close();
}
this.closed = true;
}

Expand Down
17 changes: 10 additions & 7 deletions agents/src/vad.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,21 +84,22 @@ export abstract class VAD extends (EventEmitter as new () => TypedEmitter<VADCal

export abstract class VADStream implements AsyncIterableIterator<VADEvent> {
protected static readonly FLUSH_SENTINEL = Symbol('FLUSH_SENTINEL');
protected input = new IdentityTransform<AudioFrame | typeof VADStream.FLUSH_SENTINEL>();
protected output = new IdentityTransform<VADEvent>();
protected inputWriter: WritableStreamDefaultWriter<AudioFrame | typeof VADStream.FLUSH_SENTINEL>;

protected inputReader: ReadableStreamDefaultReader<AudioFrame | typeof VADStream.FLUSH_SENTINEL>;
protected outputWriter: WritableStreamDefaultWriter<VADEvent>;
protected outputReader: ReadableStreamDefaultReader<VADEvent>;
protected closed = false;
protected inputClosed = false;

#vad: VAD;
#lastActivityTime = BigInt(0);
private logger = log();
private deferredInputStream: DeferredReadableStream<AudioFrame>;

private input = new IdentityTransform<AudioFrame | typeof VADStream.FLUSH_SENTINEL>();
private output = new IdentityTransform<VADEvent>();
private metricsStream: ReadableStream<VADEvent>;
private outputReader: ReadableStreamDefaultReader<VADEvent>;
private inputWriter: WritableStreamDefaultWriter<AudioFrame | typeof VADStream.FLUSH_SENTINEL>;
Comment on lines +97 to +101
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Making access modifiers more restrictive . Forgot to do it in #390


constructor(vad: VAD) {
this.#vad = vad;
this.deferredInputStream = new DeferredReadableStream<AudioFrame>();
Expand Down Expand Up @@ -207,7 +208,7 @@ export abstract class VADStream implements AsyncIterableIterator<VADEvent> {
throw new Error('Stream is closed');
}
this.inputClosed = true;
this.input.writable.close();
this.inputWriter.close();
}

async next(): Promise<IteratorResult<VADEvent>> {
Expand All @@ -220,7 +221,9 @@ export abstract class VADStream implements AsyncIterableIterator<VADEvent> {
}

close() {
this.input.writable.close();
if (!this.inputClosed) {
this.inputWriter.close();
}
Comment on lines +224 to +226
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

forgot to do in #390

this.closed = true;
}

Expand Down
14 changes: 8 additions & 6 deletions plugins/cartesia/src/tts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ export class ChunkedStream extends tts.ChunkedStream {
(res) => {
res.on('data', (chunk) => {
for (const frame of bstream.write(chunk)) {
this.queue.put({
this.outputWriter.write({
requestId,
frame,
final: false,
Expand All @@ -117,14 +117,14 @@ export class ChunkedStream extends tts.ChunkedStream {
});
res.on('close', () => {
for (const frame of bstream.flush()) {
this.queue.put({
this.outputWriter.write({
requestId,
frame,
final: false,
segmentId: requestId,
});
}
this.queue.close();
this.close();
});
},
);
Expand Down Expand Up @@ -178,7 +178,9 @@ export class SynthesizeStream extends tts.SynthesizeStream {
};

const inputTask = async () => {
for await (const data of this.input) {
while (true) {
const { done, value: data } = await this.inputReader.read();
if (done) break;
if (data === SynthesizeStream.FLUSH_SENTINEL) {
this.#tokenizer.flush();
continue;
Expand All @@ -195,7 +197,7 @@ export class SynthesizeStream extends tts.SynthesizeStream {
let lastFrame: AudioFrame | undefined;
const sendLastFrame = (segmentId: string, final: boolean) => {
if (lastFrame) {
this.queue.put({ requestId, segmentId, frame: lastFrame, final });
this.outputWriter.write({ requestId, segmentId, frame: lastFrame, final });
lastFrame = undefined;
}
};
Expand All @@ -215,7 +217,7 @@ export class SynthesizeStream extends tts.SynthesizeStream {
lastFrame = frame;
}
sendLastFrame(segmentId, true);
this.queue.put(SynthesizeStream.END_OF_STREAM);
this.outputWriter.write(SynthesizeStream.END_OF_STREAM);

if (segmentId === requestId) {
closing = true;
Expand Down
Loading