diff --git a/libs/checkpoint-mongodb/src/index.ts b/libs/checkpoint-mongodb/src/index.ts index 279535086..71051bbb7 100644 --- a/libs/checkpoint-mongodb/src/index.ts +++ b/libs/checkpoint-mongodb/src/index.ts @@ -16,6 +16,7 @@ export type MongoDBSaverParams = { dbName?: string; checkpointCollectionName?: string; checkpointWritesCollectionName?: string; + ttl?: { expireAfterSeconds: number }; }; /** @@ -26,22 +27,57 @@ export class MongoDBSaver extends BaseCheckpointSaver { protected db: MongoDatabase; + protected ttl: { expireAfterSeconds: number } | undefined; + + protected isSetup: boolean; + checkpointCollectionName = "checkpoints"; checkpointWritesCollectionName = "checkpoint_writes"; + async setup(): Promise { + if (this.ttl != null) { + const { expireAfterSeconds } = this.ttl; + await Promise.all([ + this.db + .collection(this.checkpointCollectionName) + .createIndex({ _createdAtForTTL: 1 }, { expireAfterSeconds }), + this.db + .collection(this.checkpointWritesCollectionName) + .createIndex({ _createdAtForTTL: 1 }, { expireAfterSeconds }), + ]); + } + + this.isSetup = true; + } + + protected assertSetup() { + // Skip setup check if TTL is not enabled + if (this.ttl == null) return; + + if (!this.isSetup) { + throw new Error( + "MongoDBSaver is not initialized. Please call `MongoDBSaver.setup()` first before using the checkpointer." + ); + } + } + constructor( { client, dbName, checkpointCollectionName, checkpointWritesCollectionName, + ttl, }: MongoDBSaverParams, serde?: SerializerProtocol ) { super(serde); this.client = client; + this.ttl = ttl; this.db = this.client.db(dbName); + this.isSetup = false; + this.checkpointCollectionName = checkpointCollectionName ?? this.checkpointCollectionName; this.checkpointWritesCollectionName = @@ -55,6 +91,8 @@ export class MongoDBSaver extends BaseCheckpointSaver { * for the given thread ID is retrieved. */ async getTuple(config: RunnableConfig): Promise { + this.assertSetup(); + const { thread_id, checkpoint_ns = "", @@ -135,6 +173,8 @@ export class MongoDBSaver extends BaseCheckpointSaver { config: RunnableConfig, options?: CheckpointListOptions ): AsyncGenerator { + this.assertSetup(); + const { limit, before, filter } = options ?? {}; const query: Record = {}; @@ -210,6 +250,8 @@ export class MongoDBSaver extends BaseCheckpointSaver { checkpoint: Checkpoint, metadata: CheckpointMetadata ): Promise { + this.assertSetup(); + const thread_id = config.configurable?.thread_id; const checkpoint_ns = config.configurable?.checkpoint_ns ?? ""; const checkpoint_id = checkpoint.id; @@ -234,6 +276,7 @@ export class MongoDBSaver extends BaseCheckpointSaver { type: checkpointType, checkpoint: serializedCheckpoint, metadata: serializedMetadata, + ...(this.ttl ? { _createdAtForTTL: new Date() } : {}), }; const upsertQuery = { thread_id, @@ -261,6 +304,8 @@ export class MongoDBSaver extends BaseCheckpointSaver { writes: PendingWrite[], taskId: string ): Promise { + this.assertSetup(); + const thread_id = config.configurable?.thread_id; const checkpoint_ns = config.configurable?.checkpoint_ns; const checkpoint_id = config.configurable?.checkpoint_id; @@ -289,7 +334,14 @@ export class MongoDBSaver extends BaseCheckpointSaver { return { updateOne: { filter: upsertQuery, - update: { $set: { channel, type, value: serializedValue } }, + update: { + $set: { + channel, + type, + value: serializedValue, + ...(this.ttl ? { _createdAtForTTL: new Date() } : {}), + }, + }, upsert: true, }, }; diff --git a/libs/checkpoint-mongodb/src/tests/checkpoints.int.test.ts b/libs/checkpoint-mongodb/src/tests/checkpoints.int.test.ts index 523bf9847..3793f2c61 100644 --- a/libs/checkpoint-mongodb/src/tests/checkpoints.int.test.ts +++ b/libs/checkpoint-mongodb/src/tests/checkpoints.int.test.ts @@ -1,4 +1,4 @@ -import { describe, it, expect, afterAll } from "vitest"; +import { describe, it, expect, afterEach, afterAll } from "vitest"; import { MongoClient } from "mongodb"; import { Checkpoint, @@ -47,128 +47,149 @@ const client = new MongoClient(getEnvironmentVariable("MONGODB_URL")!, { auth: { username: "user", password: "password" }, }); -afterAll(async () => { +afterEach(async () => { const db = client.db(); await db.dropCollection("checkpoints"); await db.dropCollection("checkpoint_writes"); +}); + +afterAll(async () => { await client.close(); }); describe("MongoDBSaver", () => { - it("should save and retrieve checkpoints correctly", async () => { - const saver = new MongoDBSaver({ client }); - - // get undefined checkpoint - const undefinedCheckpoint = await saver.getTuple({ - configurable: { thread_id: "1" }, - }); - expect(undefinedCheckpoint).toBeUndefined(); - - // save first checkpoint - const runnableConfig = await saver.put( - { configurable: { thread_id: "1" } }, - checkpoint1, - { source: "update", step: -1, parents: {} } - ); - expect(runnableConfig).toEqual({ - configurable: { - thread_id: "1", - checkpoint_ns: "", - checkpoint_id: checkpoint1.id, - }, - }); - - // add some writes - await saver.putWrites( - { + it.each([{ ttl: undefined }, { ttl: { expireAfterSeconds: 60 * 60 } }])( + "should save and retrieve checkpoints correctly (%s)", + async ({ ttl }) => { + const saver = new MongoDBSaver({ client, ttl }); + await saver.setup(); + + const threadId = crypto.randomUUID(); + + // get undefined checkpoint + const undefinedCheckpoint = await saver.getTuple({ + configurable: { thread_id: threadId }, + }); + expect(undefinedCheckpoint).toBeUndefined(); + + // save first checkpoint + const runnableConfig = await saver.put( + { configurable: { thread_id: threadId } }, + checkpoint1, + { source: "update", step: -1, parents: {} } + ); + expect(runnableConfig).toEqual({ configurable: { + thread_id: threadId, + checkpoint_ns: "", checkpoint_id: checkpoint1.id, + }, + }); + + // add some writes + await saver.putWrites( + { + configurable: { + checkpoint_id: checkpoint1.id, + checkpoint_ns: "", + thread_id: threadId, + }, + }, + [["bar", "baz"]], + "foo" + ); + + // get first checkpoint tuple + const firstCheckpointTuple = await saver.getTuple({ + configurable: { thread_id: threadId }, + }); + expect(firstCheckpointTuple?.config).toEqual({ + configurable: { + thread_id: threadId, checkpoint_ns: "", - thread_id: "1", + checkpoint_id: checkpoint1.id, }, - }, - [["bar", "baz"]], - "foo" - ); - - // get first checkpoint tuple - const firstCheckpointTuple = await saver.getTuple({ - configurable: { thread_id: "1" }, - }); - expect(firstCheckpointTuple?.config).toEqual({ - configurable: { - thread_id: "1", - checkpoint_ns: "", - checkpoint_id: checkpoint1.id, - }, - }); - expect(firstCheckpointTuple?.checkpoint).toEqual(checkpoint1); - expect(firstCheckpointTuple?.parentConfig).toBeUndefined(); - expect(firstCheckpointTuple?.pendingWrites).toEqual([ - ["foo", "bar", "baz"], - ]); - - // save second checkpoint - await saver.put( - { + }); + expect(firstCheckpointTuple?.checkpoint).toEqual(checkpoint1); + expect(firstCheckpointTuple?.parentConfig).toBeUndefined(); + expect(firstCheckpointTuple?.pendingWrites).toEqual([ + ["foo", "bar", "baz"], + ]); + + // save second checkpoint + await saver.put( + { + configurable: { + thread_id: threadId, + checkpoint_id: "2024-04-18T17:19:07.952Z", + }, + }, + checkpoint2, + { source: "update", step: -1, parents: {} } + ); + + // verify that parentTs is set and retrieved correctly for second checkpoint + const secondCheckpointTuple = await saver.getTuple({ + configurable: { thread_id: threadId }, + }); + expect(secondCheckpointTuple?.parentConfig).toEqual({ configurable: { - thread_id: "1", + thread_id: threadId, + checkpoint_ns: "", checkpoint_id: "2024-04-18T17:19:07.952Z", }, - }, - checkpoint2, - { source: "update", step: -1, parents: {} } - ); - - // verify that parentTs is set and retrieved correctly for second checkpoint - const secondCheckpointTuple = await saver.getTuple({ - configurable: { thread_id: "1" }, - }); - expect(secondCheckpointTuple?.parentConfig).toEqual({ - configurable: { - thread_id: "1", - checkpoint_ns: "", - checkpoint_id: "2024-04-18T17:19:07.952Z", - }, - }); - - // list checkpoints - const checkpointTupleGenerator = saver.list({ - configurable: { thread_id: "1" }, - }); - const checkpointTuples: CheckpointTuple[] = []; - for await (const checkpoint of checkpointTupleGenerator) { - checkpointTuples.push(checkpoint); + }); + + // list checkpoints + const checkpointTupleGenerator = saver.list({ + configurable: { thread_id: threadId }, + }); + const checkpointTuples: CheckpointTuple[] = []; + for await (const checkpoint of checkpointTupleGenerator) { + checkpointTuples.push(checkpoint); + } + expect(checkpointTuples.length).toBe(2); + + const checkpointTuple1 = checkpointTuples[0]; + const checkpointTuple2 = checkpointTuples[1]; + expect(checkpointTuple1.checkpoint.ts).toBe("2024-04-20T17:19:07.952Z"); + expect(checkpointTuple2.checkpoint.ts).toBe("2024-04-19T17:19:07.952Z"); } - expect(checkpointTuples.length).toBe(2); - - const checkpointTuple1 = checkpointTuples[0]; - const checkpointTuple2 = checkpointTuples[1]; - expect(checkpointTuple1.checkpoint.ts).toBe("2024-04-20T17:19:07.952Z"); - expect(checkpointTuple2.checkpoint.ts).toBe("2024-04-19T17:19:07.952Z"); - }); + ); it("should delete thread", async () => { + const threadId1 = crypto.randomUUID(); + const threadId2 = crypto.randomUUID(); + const saver = new MongoDBSaver({ client }); - await saver.put({ configurable: { thread_id: "1" } }, emptyCheckpoint(), { - source: "update", - step: -1, - parents: {}, - }); - await saver.put({ configurable: { thread_id: "2" } }, emptyCheckpoint(), { - source: "update", - step: -1, - parents: {}, - }); + await saver.put( + { configurable: { thread_id: threadId1 } }, + emptyCheckpoint(), + { + source: "update", + step: -1, + parents: {}, + } + ); + + await saver.put( + { configurable: { thread_id: threadId2 } }, + emptyCheckpoint(), + { + source: "update", + step: -1, + parents: {}, + } + ); - await saver.deleteThread("1"); + await saver.deleteThread(threadId1); expect( - await saver.getTuple({ configurable: { thread_id: "1" } }) + await saver.getTuple({ configurable: { thread_id: threadId1 } }) ).toBeUndefined(); expect( - await saver.getTuple({ configurable: { thread_id: "2" } }) + await saver.getTuple({ configurable: { thread_id: threadId2 } }) ).toBeDefined(); }); });