from __future__ import annotations

from datetime import datetime, timezone
from typing import Any, AsyncIterator, Iterable, Iterator, Sequence

from bson import Binary
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import (
    WRITES_IDX_MAP,
    BaseCheckpointSaver,
    ChannelVersions,
    Checkpoint,
    CheckpointMetadata,
    CheckpointTuple,
)
from motor.motor_asyncio import AsyncIOMotorClient
from pymongo import MongoClient, UpdateOne
class MongoDBCheckpointSaver(BaseCheckpointSaver):
    """LangGraph checkpoint saver backed by MongoDB.

    Each checkpoint is one document in ``collection_name``, keyed by
    ``(thread_id, checkpoint_id)`` and tagged with its ``checkpoint_ns`` and
    ``parent_checkpoint_id``. Pending writes from ``put_writes`` are one
    document per write in a companion collection
    (``<collection_name>_writes`` by default).

    Sync methods go through a PyMongo client and async methods through a Motor
    client, both opened on the same ``uri``.

    Documents written by earlier versions of this class have no
    ``checkpoint_ns`` or ``parent_checkpoint_id`` field. They are read as
    root-namespace checkpoints with no parent.
    """

    def __init__(
        self,
        uri: str = "mongodb://localhost:27017/",
        db_name: str = "langgraph",
        collection_name: str = "checkpoints",
        writes_collection_name: str | None = None,
    ):
        """Connect to MongoDB and ensure the required indexes exist.

        Args:
            uri: MongoDB connection string, shared by the sync and async clients.
            db_name: Database holding both collections.
            collection_name: Collection for checkpoint documents.
            writes_collection_name: Collection for pending writes. Defaults to
                ``f"{collection_name}_writes"``.
        """
        super().__init__()
        writes_name = writes_collection_name or f"{collection_name}_writes"

        self.client = MongoClient(uri)
        self.db = self.client[db_name]
        self.collection = self.db[collection_name]
        self.writes_collection = self.db[writes_name]

        self.async_client = AsyncIOMotorClient(uri)
        self.async_db = self.async_client[db_name]
        self.async_collection = self.async_db[collection_name]
        self.async_writes_collection = self.async_db[writes_name]

        self.collection.create_index("thread_id")
        self.collection.create_index(
            [("thread_id", 1), ("checkpoint_id", 1)], unique=True
        )
        # One document per (checkpoint, task, write index). Uniqueness makes
        # the upserts in put_writes/aput_writes idempotent.
        self.writes_collection.create_index(
            [
                ("thread_id", 1),
                ("checkpoint_ns", 1),
                ("checkpoint_id", 1),
                ("task_id", 1),
                ("idx", 1),
            ],
            unique=True,
        )

    def close(self) -> None:
        """Close both the sync and the async MongoDB client."""
        self.client.close()
        self.async_client.close()

    # ------------------------------------------------------------------
    # Serialization helpers
    # ------------------------------------------------------------------

    def _serialize(self, obj: Any) -> dict:
        """Serialize *obj* with ``self.serde`` into a BSON-storable dict."""
        typ, data = self.serde.dumps_typed(obj)
        return {"type": typ, "data": Binary(data)}

    def _deserialize(self, data: dict) -> Any:
        """Inverse of :meth:`_serialize`."""
        return self.serde.loads_typed((data["type"], data["data"]))

    # ------------------------------------------------------------------
    # Query / document helpers shared by the sync and async APIs
    # ------------------------------------------------------------------

    @staticmethod
    def _ns_match(checkpoint_ns: str) -> Any:
        """Return a Mongo condition selecting documents in *checkpoint_ns*.

        The root namespace ("") also matches documents that have no
        ``checkpoint_ns`` field (written before namespaces were stored).
        ``{"$in": [..., None]}`` matches null *or* missing fields.
        """
        if checkpoint_ns:
            return checkpoint_ns
        return {"$in": ["", None]}

    def _tuple_query(self, config: RunnableConfig) -> dict[str, Any]:
        """Build the checkpoint filter used by get_tuple/aget_tuple.

        If the config carries a ``checkpoint_id``, only that checkpoint matches.
        Otherwise every checkpoint of the thread/namespace matches, and the
        caller takes the newest one.

        Raises:
            KeyError: if ``config["configurable"]["thread_id"]`` is missing.
        """
        configurable = config["configurable"]
        query: dict[str, Any] = {
            "thread_id": configurable["thread_id"],
            "checkpoint_ns": self._ns_match(configurable.get("checkpoint_ns", "")),
        }
        checkpoint_id = configurable.get("checkpoint_id")
        if checkpoint_id:
            query["checkpoint_id"] = checkpoint_id
        return query

    def _list_query(
        self, config: RunnableConfig | None, before: RunnableConfig | None
    ) -> dict[str, Any]:
        """Build the checkpoint filter used by list/alist.

        Each of thread_id, checkpoint_ns and checkpoint_id is applied only when
        present in *config*. *before* restricts results to ids lower than its
        ``checkpoint_id``.
        """
        configurable = (config or {}).get("configurable") or {}
        query: dict[str, Any] = {}

        thread_id = configurable.get("thread_id")
        if thread_id:
            query["thread_id"] = thread_id
        if "checkpoint_ns" in configurable:
            query["checkpoint_ns"] = self._ns_match(configurable["checkpoint_ns"])

        id_condition: dict[str, Any] = {}
        checkpoint_id = configurable.get("checkpoint_id")
        if checkpoint_id:
            id_condition["$eq"] = checkpoint_id
        before_id = ((before or {}).get("configurable") or {}).get("checkpoint_id")
        if before_id:
            # Assumes checkpoint ids increase monotonically. The newest-first
            # sort on checkpoint_id used throughout relies on the same property.
            id_condition["$lt"] = before_id
        if id_condition:
            query["checkpoint_id"] = id_condition
        return query

    @staticmethod
    def _matches_filter(metadata: Any, filter_: dict[str, Any] | None) -> bool:
        """Return True if every key in *filter_* equals that key in *metadata*.

        Metadata is stored serialized, so the filter is applied client-side
        after deserialization.
        """
        if not filter_:
            return True
        metadata = metadata or {}
        return all(metadata.get(key) == value for key, value in filter_.items())

    @staticmethod
    def _key(doc: dict[str, Any]) -> dict[str, Any]:
        """Return a fresh identity dict (thread, namespace, id) for a checkpoint doc."""
        return {
            "thread_id": doc["thread_id"],
            "checkpoint_ns": doc.get("checkpoint_ns", ""),
            "checkpoint_id": doc["checkpoint_id"],
        }

    def _writes_cursor(self, collection: Any, doc: dict[str, Any]) -> Any:
        """Return a cursor over *doc*'s pending writes, ordered by task then index.

        *collection* may be the PyMongo or the Motor writes collection. Both
        expose ``find(...).sort(...)``.
        """
        return collection.find(self._key(doc)).sort([("task_id", 1), ("idx", 1)])

    def _to_tuple(
        self,
        doc: dict[str, Any],
        metadata: Any,
        write_docs: Iterable[dict[str, Any]],
    ) -> CheckpointTuple:
        """Convert a stored checkpoint document plus its write documents into a tuple.

        The returned ``config`` identifies this exact checkpoint. ``parent_config``
        is built only when a parent id was recorded.
        """
        key = self._key(doc)
        parent_id = doc.get("parent_checkpoint_id")
        parent_config = (
            {"configurable": {**key, "checkpoint_id": parent_id}} if parent_id else None
        )
        return CheckpointTuple(
            config={"configurable": key},
            checkpoint=self._deserialize(doc["checkpoint"]),
            metadata=metadata,
            parent_config=parent_config,
            pending_writes=[
                (w["task_id"], w["channel"], self._deserialize(w["value"]))
                for w in write_docs
            ],
        )

    def _checkpoint_doc(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
    ) -> dict[str, Any]:
        """Build the document stored by put/aput."""
        configurable = config["configurable"]
        return {
            "thread_id": configurable["thread_id"],
            "checkpoint_ns": configurable.get("checkpoint_ns", ""),
            "checkpoint_id": checkpoint["id"],
            # The incoming config addresses the checkpoint this one was
            # derived from (None for the first checkpoint of a thread).
            "parent_checkpoint_id": configurable.get("checkpoint_id"),
            "checkpoint": self._serialize(checkpoint),
            "metadata": self._serialize(metadata),
            "updated_at": datetime.now(timezone.utc),
        }

    def _write_ops(
        self,
        config: RunnableConfig,
        writes: Sequence[tuple[str, Any]],
        task_id: str,
        task_path: str,
    ) -> Sequence[UpdateOne]:
        """Build one upsert per pending write for put_writes/aput_writes."""
        configurable = config["configurable"]
        base = {
            "thread_id": configurable["thread_id"],
            "checkpoint_ns": configurable.get("checkpoint_ns", ""),
            "checkpoint_id": configurable["checkpoint_id"],
            "task_id": task_id,
        }
        ops = []
        for position, (channel, value) in enumerate(writes):
            idx = WRITES_IDX_MAP.get(channel, position)
            # Channels listed in WRITES_IDX_MAP get fixed negative indexes and
            # always overwrite. Ordinary writes keep the first stored value,
            # mirroring LangGraph's in-memory saver.
            operator = "$set" if idx < 0 else "$setOnInsert"
            ops.append(
                UpdateOne(
                    {**base, "idx": idx},
                    {
                        operator: {
                            "channel": channel,
                            "task_path": task_path,
                            "value": self._serialize(value),
                        }
                    },
                    upsert=True,
                )
            )
        return ops

    # ------------------------------------------------------------------
    # Sync API
    # ------------------------------------------------------------------

    def get_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
        """Load the checkpoint addressed by *config*.

        If the config has a ``checkpoint_id``, that checkpoint is loaded.
        Otherwise the newest checkpoint of the thread/namespace is loaded.
        Returns None when nothing matches.
        """
        doc = self.collection.find_one(
            self._tuple_query(config), sort=[("checkpoint_id", -1)]
        )
        if doc is None:
            return None
        return self._to_tuple(
            doc,
            self._deserialize(doc["metadata"]),
            self._writes_cursor(self.writes_collection, doc),
        )

    def list(
        self,
        config: RunnableConfig | None,
        *,
        filter: dict[str, Any] | None = None,  # noqa: A002 - name fixed by the base API
        before: RunnableConfig | None = None,
        limit: int | None = None,
        **kwargs: Any,
    ) -> Iterator[CheckpointTuple]:
        """Yield matching checkpoints, newest first.

        Args:
            config: Optional scope. Supported keys: thread_id, checkpoint_ns,
                checkpoint_id.
            filter: Metadata key/value pairs that must all match.
            before: Only yield checkpoints with ids below this config's
                checkpoint_id.
            limit: Maximum number of tuples to yield.
        """
        cursor = self.collection.find(
            self._list_query(config, before), sort=[("checkpoint_id", -1)]
        )
        emitted = 0
        for doc in cursor:
            if limit is not None and emitted >= limit:
                break
            metadata = self._deserialize(doc["metadata"])
            if not self._matches_filter(metadata, filter):
                continue
            yield self._to_tuple(
                doc, metadata, self._writes_cursor(self.writes_collection, doc)
            )
            emitted += 1

    def put(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions,
    ) -> RunnableConfig:
        """Store *checkpoint* and return a config addressing it.

        The whole checkpoint is stored, so *new_versions* is not used. The
        returned config carries thread_id, checkpoint_ns and the new
        checkpoint_id, which callers need to address this checkpoint later.
        """
        doc = self._checkpoint_doc(config, checkpoint, metadata)
        self.collection.update_one(
            {"thread_id": doc["thread_id"], "checkpoint_id": doc["checkpoint_id"]},
            {"$set": doc},
            upsert=True,
        )
        return {"configurable": self._key(doc)}

    def put_writes(
        self,
        config: RunnableConfig,
        writes: Sequence[tuple[str, Any]],
        task_id: str,
        task_path: str = "",
    ) -> None:
        """Persist pending ``(channel, value)`` writes of *task_id* for the checkpoint in *config*.

        Raises:
            KeyError: if the config lacks thread_id or checkpoint_id.
        """
        ops = self._write_ops(config, writes, task_id, task_path)
        if ops:  # bulk_write rejects an empty operation list
            self.writes_collection.bulk_write(ops)

    # ------------------------------------------------------------------
    # Async API (same semantics as the sync methods above)
    # ------------------------------------------------------------------

    async def aget_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
        """Async version of :meth:`get_tuple`."""
        doc = await self.async_collection.find_one(
            self._tuple_query(config), sort=[("checkpoint_id", -1)]
        )
        if doc is None:
            return None
        write_docs = [
            w async for w in self._writes_cursor(self.async_writes_collection, doc)
        ]
        return self._to_tuple(doc, self._deserialize(doc["metadata"]), write_docs)

    async def alist(
        self,
        config: RunnableConfig | None,
        *,
        filter: dict[str, Any] | None = None,  # noqa: A002 - name fixed by the base API
        before: RunnableConfig | None = None,
        limit: int | None = None,
        **kwargs: Any,
    ) -> AsyncIterator[CheckpointTuple]:
        """Async version of :meth:`list`."""
        cursor = self.async_collection.find(self._list_query(config, before)).sort(
            "checkpoint_id", -1
        )
        emitted = 0
        async for doc in cursor:
            if limit is not None and emitted >= limit:
                break
            metadata = self._deserialize(doc["metadata"])
            if not self._matches_filter(metadata, filter):
                continue
            write_docs = [
                w async for w in self._writes_cursor(self.async_writes_collection, doc)
            ]
            yield self._to_tuple(doc, metadata, write_docs)
            emitted += 1

    async def aput(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions,
    ) -> RunnableConfig:
        """Async version of :meth:`put`."""
        doc = self._checkpoint_doc(config, checkpoint, metadata)
        await self.async_collection.update_one(
            {"thread_id": doc["thread_id"], "checkpoint_id": doc["checkpoint_id"]},
            {"$set": doc},
            upsert=True,
        )
        return {"configurable": self._key(doc)}

    async def aput_writes(
        self,
        config: RunnableConfig,
        writes: Sequence[tuple[str, Any]],
        task_id: str,
        task_path: str = "",
    ) -> None:
        """Async version of :meth:`put_writes`."""
        ops = self._write_ops(config, writes, task_id, task_path)
        if ops:  # bulk_write rejects an empty operation list
            await self.async_writes_collection.bulk_write(ops)