Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 53 additions & 1 deletion libs/checkpoint-mongodb/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ export type MongoDBSaverParams = {
dbName?: string;
checkpointCollectionName?: string;
checkpointWritesCollectionName?: string;
ttl?: { expireAfterSeconds: number };
};

/**
Expand All @@ -26,22 +27,57 @@ export class MongoDBSaver extends BaseCheckpointSaver {

protected db: MongoDatabase;

protected ttl: { expireAfterSeconds: number } | undefined;

protected isSetup: boolean;

checkpointCollectionName = "checkpoints";

checkpointWritesCollectionName = "checkpoint_writes";

async setup(): Promise<void> {
if (this.ttl != null) {
const { expireAfterSeconds } = this.ttl;
await Promise.all([
this.db
.collection(this.checkpointCollectionName)
.createIndex({ _createdAtForTTL: 1 }, { expireAfterSeconds }),
this.db
.collection(this.checkpointWritesCollectionName)
.createIndex({ _createdAtForTTL: 1 }, { expireAfterSeconds }),
]);
}

this.isSetup = true;
}

protected assertSetup() {
// Skip setup check if TTL is not enabled
if (this.ttl == null) return;

if (!this.isSetup) {
throw new Error(
"MongoDBSaver is not initialized. Please call `MongoDBSaver.setup()` first before using the checkpointer."
);
}
}

constructor(
{
client,
dbName,
checkpointCollectionName,
checkpointWritesCollectionName,
ttl,
}: MongoDBSaverParams,
serde?: SerializerProtocol
) {
super(serde);
this.client = client;
this.ttl = ttl;
this.db = this.client.db(dbName);
this.isSetup = false;

this.checkpointCollectionName =
checkpointCollectionName ?? this.checkpointCollectionName;
this.checkpointWritesCollectionName =
Expand All @@ -55,6 +91,8 @@ export class MongoDBSaver extends BaseCheckpointSaver {
* for the given thread ID is retrieved.
*/
async getTuple(config: RunnableConfig): Promise<CheckpointTuple | undefined> {
this.assertSetup();

const {
thread_id,
checkpoint_ns = "",
Expand Down Expand Up @@ -135,6 +173,8 @@ export class MongoDBSaver extends BaseCheckpointSaver {
config: RunnableConfig,
options?: CheckpointListOptions
): AsyncGenerator<CheckpointTuple> {
this.assertSetup();

const { limit, before, filter } = options ?? {};
const query: Record<string, unknown> = {};

Expand Down Expand Up @@ -210,6 +250,8 @@ export class MongoDBSaver extends BaseCheckpointSaver {
checkpoint: Checkpoint,
metadata: CheckpointMetadata
): Promise<RunnableConfig> {
this.assertSetup();

const thread_id = config.configurable?.thread_id;
const checkpoint_ns = config.configurable?.checkpoint_ns ?? "";
const checkpoint_id = checkpoint.id;
Expand All @@ -234,6 +276,7 @@ export class MongoDBSaver extends BaseCheckpointSaver {
type: checkpointType,
checkpoint: serializedCheckpoint,
metadata: serializedMetadata,
...(this.ttl ? { _createdAtForTTL: new Date() } : {}),
};
const upsertQuery = {
thread_id,
Expand Down Expand Up @@ -261,6 +304,8 @@ export class MongoDBSaver extends BaseCheckpointSaver {
writes: PendingWrite[],
taskId: string
): Promise<void> {
this.assertSetup();

const thread_id = config.configurable?.thread_id;
const checkpoint_ns = config.configurable?.checkpoint_ns;
const checkpoint_id = config.configurable?.checkpoint_id;
Expand Down Expand Up @@ -289,7 +334,14 @@ export class MongoDBSaver extends BaseCheckpointSaver {
return {
updateOne: {
filter: upsertQuery,
update: { $set: { channel, type, value: serializedValue } },
update: {
$set: {
channel,
type,
value: serializedValue,
...(this.ttl ? { _createdAtForTTL: new Date() } : {}),
},
},
upsert: true,
},
};
Expand Down
219 changes: 120 additions & 99 deletions libs/checkpoint-mongodb/src/tests/checkpoints.int.test.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { describe, it, expect, afterAll } from "vitest";
import { describe, it, expect, afterEach, afterAll } from "vitest";
import { MongoClient } from "mongodb";
import {
Checkpoint,
Expand Down Expand Up @@ -47,128 +47,149 @@ const client = new MongoClient(getEnvironmentVariable("MONGODB_URL")!, {
auth: { username: "user", password: "password" },
});

afterAll(async () => {
afterEach(async () => {
const db = client.db();
await db.dropCollection("checkpoints");
await db.dropCollection("checkpoint_writes");
});

afterAll(async () => {
await client.close();
});

describe("MongoDBSaver", () => {
it("should save and retrieve checkpoints correctly", async () => {
const saver = new MongoDBSaver({ client });

// get undefined checkpoint
const undefinedCheckpoint = await saver.getTuple({
configurable: { thread_id: "1" },
});
expect(undefinedCheckpoint).toBeUndefined();

// save first checkpoint
const runnableConfig = await saver.put(
{ configurable: { thread_id: "1" } },
checkpoint1,
{ source: "update", step: -1, parents: {} }
);
expect(runnableConfig).toEqual({
configurable: {
thread_id: "1",
checkpoint_ns: "",
checkpoint_id: checkpoint1.id,
},
});

// add some writes
await saver.putWrites(
{
it.each([{ ttl: undefined }, { ttl: { expireAfterSeconds: 60 * 60 } }])(
"should save and retrieve checkpoints correctly (%s)",
async ({ ttl }) => {
const saver = new MongoDBSaver({ client, ttl });
await saver.setup();

const threadId = crypto.randomUUID();

// get undefined checkpoint
const undefinedCheckpoint = await saver.getTuple({
configurable: { thread_id: threadId },
});
expect(undefinedCheckpoint).toBeUndefined();

// save first checkpoint
const runnableConfig = await saver.put(
{ configurable: { thread_id: threadId } },
checkpoint1,
{ source: "update", step: -1, parents: {} }
);
expect(runnableConfig).toEqual({
configurable: {
thread_id: threadId,
checkpoint_ns: "",
checkpoint_id: checkpoint1.id,
},
});

// add some writes
await saver.putWrites(
{
configurable: {
checkpoint_id: checkpoint1.id,
checkpoint_ns: "",
thread_id: threadId,
},
},
[["bar", "baz"]],
"foo"
);

// get first checkpoint tuple
const firstCheckpointTuple = await saver.getTuple({
configurable: { thread_id: threadId },
});
expect(firstCheckpointTuple?.config).toEqual({
configurable: {
thread_id: threadId,
checkpoint_ns: "",
thread_id: "1",
checkpoint_id: checkpoint1.id,
},
},
[["bar", "baz"]],
"foo"
);

// get first checkpoint tuple
const firstCheckpointTuple = await saver.getTuple({
configurable: { thread_id: "1" },
});
expect(firstCheckpointTuple?.config).toEqual({
configurable: {
thread_id: "1",
checkpoint_ns: "",
checkpoint_id: checkpoint1.id,
},
});
expect(firstCheckpointTuple?.checkpoint).toEqual(checkpoint1);
expect(firstCheckpointTuple?.parentConfig).toBeUndefined();
expect(firstCheckpointTuple?.pendingWrites).toEqual([
["foo", "bar", "baz"],
]);

// save second checkpoint
await saver.put(
{
});
expect(firstCheckpointTuple?.checkpoint).toEqual(checkpoint1);
expect(firstCheckpointTuple?.parentConfig).toBeUndefined();
expect(firstCheckpointTuple?.pendingWrites).toEqual([
["foo", "bar", "baz"],
]);

// save second checkpoint
await saver.put(
{
configurable: {
thread_id: threadId,
checkpoint_id: "2024-04-18T17:19:07.952Z",
},
},
checkpoint2,
{ source: "update", step: -1, parents: {} }
);

// verify that parentTs is set and retrieved correctly for second checkpoint
const secondCheckpointTuple = await saver.getTuple({
configurable: { thread_id: threadId },
});
expect(secondCheckpointTuple?.parentConfig).toEqual({
configurable: {
thread_id: "1",
thread_id: threadId,
checkpoint_ns: "",
checkpoint_id: "2024-04-18T17:19:07.952Z",
},
},
checkpoint2,
{ source: "update", step: -1, parents: {} }
);

// verify that parentTs is set and retrieved correctly for second checkpoint
const secondCheckpointTuple = await saver.getTuple({
configurable: { thread_id: "1" },
});
expect(secondCheckpointTuple?.parentConfig).toEqual({
configurable: {
thread_id: "1",
checkpoint_ns: "",
checkpoint_id: "2024-04-18T17:19:07.952Z",
},
});

// list checkpoints
const checkpointTupleGenerator = saver.list({
configurable: { thread_id: "1" },
});
const checkpointTuples: CheckpointTuple[] = [];
for await (const checkpoint of checkpointTupleGenerator) {
checkpointTuples.push(checkpoint);
});

// list checkpoints
const checkpointTupleGenerator = saver.list({
configurable: { thread_id: threadId },
});
const checkpointTuples: CheckpointTuple[] = [];
for await (const checkpoint of checkpointTupleGenerator) {
checkpointTuples.push(checkpoint);
}
expect(checkpointTuples.length).toBe(2);

const checkpointTuple1 = checkpointTuples[0];
const checkpointTuple2 = checkpointTuples[1];
expect(checkpointTuple1.checkpoint.ts).toBe("2024-04-20T17:19:07.952Z");
expect(checkpointTuple2.checkpoint.ts).toBe("2024-04-19T17:19:07.952Z");
}
expect(checkpointTuples.length).toBe(2);

const checkpointTuple1 = checkpointTuples[0];
const checkpointTuple2 = checkpointTuples[1];
expect(checkpointTuple1.checkpoint.ts).toBe("2024-04-20T17:19:07.952Z");
expect(checkpointTuple2.checkpoint.ts).toBe("2024-04-19T17:19:07.952Z");
});
);

it("should delete thread", async () => {
const threadId1 = crypto.randomUUID();
const threadId2 = crypto.randomUUID();

const saver = new MongoDBSaver({ client });
await saver.put({ configurable: { thread_id: "1" } }, emptyCheckpoint(), {
source: "update",
step: -1,
parents: {},
});

await saver.put({ configurable: { thread_id: "2" } }, emptyCheckpoint(), {
source: "update",
step: -1,
parents: {},
});
await saver.put(
{ configurable: { thread_id: threadId1 } },
emptyCheckpoint(),
{
source: "update",
step: -1,
parents: {},
}
);

await saver.put(
{ configurable: { thread_id: threadId2 } },
emptyCheckpoint(),
{
source: "update",
step: -1,
parents: {},
}
);

await saver.deleteThread("1");
await saver.deleteThread(threadId1);

expect(
await saver.getTuple({ configurable: { thread_id: "1" } })
await saver.getTuple({ configurable: { thread_id: threadId1 } })
).toBeUndefined();
expect(
await saver.getTuple({ configurable: { thread_id: "2" } })
await saver.getTuple({ configurable: { thread_id: threadId2 } })
).toBeDefined();
});
});
Loading