Skip to content

Commit 1bd7466

Browse files
committed
feat: Qdrant support
Signed-off-by: Anush008 <[email protected]>
1 parent 69dea74 commit 1bd7466

File tree

4 files changed

+269
-4
lines changed

4 files changed

+269
-4
lines changed

adala/memories/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .file_memory import FileMemory
22
from .vectordb import VectorDBMemory
3+
from .qdrant_memory import QdrantMemory
34
from .base import Memory

adala/memories/qdrant_memory.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
from typing import Any, List, Dict, Optional
2+
import uuid
3+
from pydantic import Field, model_validator
4+
5+
from .base import Memory
6+
7+
try:
8+
from qdrant_client import QdrantClient
9+
from qdrant_client.models import Distance, VectorParams, PointStruct
10+
import openai
11+
12+
QDRANT_AVAILABLE = True
13+
except ImportError:
14+
QDRANT_AVAILABLE = False
15+
16+
17+
class QdrantMemory(Memory):
18+
"""
19+
Memory backed by [Qdrant](https://qdrant.tech/).
20+
"""
21+
22+
model_config = {"arbitrary_types_allowed": True}
23+
24+
collection_name: str = Field(..., description="Name of the Qdrant collection")
25+
openai_api_key: str = Field(..., description="OpenAI API key for embeddings")
26+
openai_embedding_model: str = Field(
27+
default="text-embedding-3-small", description="OpenAI embedding model"
28+
)
29+
qdrant_url: Optional[str] = Field(
30+
default=None, description="Qdrant server URL"
31+
)
32+
qdrant_api_key: Optional[str] = Field(
33+
default=None, description="Qdrant API key for remote instances"
34+
)
35+
qdrant_client: Optional[QdrantClient] = Field(
36+
default=None, description="Pre-configured QdrantClient instance"
37+
)
38+
dimension: int = Field(default=1536, description="Vector dimension size")
39+
distance_metric: str = Field(
40+
default="Cosine", description="Distance metric for similarity search"
41+
)
42+
43+
_client: Optional[QdrantClient] = None
44+
_openai_client: Optional[openai.OpenAI] = None
45+
46+
@model_validator(mode="after")
47+
def init_database(self):
48+
if not QDRANT_AVAILABLE:
49+
raise ImportError(
50+
"Qdrant dependencies not available. "
51+
"Please install with: pip install qdrant-client openai"
52+
)
53+
54+
if self.qdrant_client is not None and (
55+
self.qdrant_url is not None or self.qdrant_api_key is not None
56+
):
57+
raise ValueError(
58+
"Cannot specify both 'qdrant_client' and 'qdrant_url'/'qdrant_api_key'. "
59+
"Use either a pre-configured QdrantClient or URL-based configuration, not both."
60+
)
61+
62+
if self.qdrant_client is not None:
63+
self._client = self.qdrant_client
64+
elif self.qdrant_url:
65+
self._client = QdrantClient(
66+
url=self.qdrant_url, api_key=self.qdrant_api_key
67+
)
68+
else:
69+
raise ValueError(
70+
"No Qdrant configuration provided. Please specify either 'qdrant_client' "
71+
"or 'qdrant_url' to configure the Qdrant connection."
72+
)
73+
74+
if not self.openai_api_key:
75+
raise ValueError("OpenAI API key is required but not provided")
76+
self._openai_client = openai.OpenAI(api_key=self.openai_api_key)
77+
78+
if not self._client.collection_exists(self.collection_name):
79+
self._client.create_collection(
80+
collection_name=self.collection_name,
81+
vectors_config=VectorParams(
82+
size=self.dimension, distance=self._get_distance_metric()
83+
),
84+
)
85+
86+
return self
87+
88+
def _generate_uuid(self, string: str) -> str:
89+
return uuid.uuid5(uuid.NAMESPACE_URL, string).hex
90+
91+
def _get_distance_metric(self) -> Distance:
92+
distance_map = {
93+
"Cosine": Distance.COSINE,
94+
"Dot": Distance.DOT,
95+
"Euclidean": Distance.EUCLID,
96+
"Manhattan": Distance.MANHATTAN,
97+
}
98+
return distance_map.get(self.distance_metric, Distance.COSINE)
99+
100+
def _get_embedding(self, text: str) -> List[float]:
101+
response = self._openai_client.embeddings.create(
102+
model=self.openai_embedding_model, input=text
103+
)
104+
return response.data[0].embedding
105+
106+
def _get_embeddings(self, texts: List[str]) -> List[List[float]]:
107+
response = self._openai_client.embeddings.create(
108+
model=self.openai_embedding_model, input=texts
109+
)
110+
return [data.embedding for data in response.data]
111+
112+
def remember(self, observation: str, data: Any):
113+
"""Store a single observation with its associated data."""
114+
self.remember_many([observation], [data])
115+
116+
def remember_many(self, observations: List[str], data: List[Dict]):
117+
"""Store multiple observations with their associated data."""
118+
119+
data = [{k: v for k, v in d.items() if v is not None} for d in data]
120+
121+
embeddings = self._get_embeddings(observations)
122+
123+
points = []
124+
for obs, embedding, metadata in zip(observations, embeddings, data):
125+
point_id = self._generate_uuid(obs)
126+
points.append(
127+
PointStruct(
128+
id=point_id, vector=embedding, payload={"text": obs, **metadata}
129+
)
130+
)
131+
132+
self._client.upsert(collection_name=self.collection_name, points=points)
133+
134+
def retrieve_many(self, observations: List[str], num_results: int = 1) -> List[Any]:
135+
"""Retrieve similar observations for multiple queries."""
136+
results = []
137+
138+
for observation in observations:
139+
query_embedding = self._get_embedding(observation)
140+
141+
search_results = self._client.query_points(
142+
collection_name=self.collection_name,
143+
query=query_embedding,
144+
limit=num_results,
145+
with_payload=True,
146+
).points
147+
148+
metadatas = []
149+
for result in search_results:
150+
payload = result.payload.copy()
151+
152+
payload.pop("text", None)
153+
metadatas.append(payload)
154+
155+
results.append(metadatas)
156+
157+
return results
158+
159+
def retrieve(self, observation: str, num_results: int = 1) -> Any:
160+
"""Retrieve similar observations for a single query."""
161+
return self.retrieve_many([observation], num_results=num_results)[0]
162+
163+
def clear(self):
164+
"""Clear all data from the collection."""
165+
166+
if self._client.collection_exists(self.collection_name):
167+
self._client.delete_collection(self.collection_name)
168+
169+
self._client.create_collection(
170+
collection_name=self.collection_name,
171+
vectors_config=VectorParams(
172+
size=self.dimension, distance=self._get_distance_metric()
173+
),
174+
)

poetry.lock

Lines changed: 92 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ dependencies = [
4646
"pandarallel (>=1.6.5,<2.0.0)",
4747
"instructor (==1.4.3)",
4848
"async-lru (>=2.0.5,<3.0.0)",
49-
"jinja2 (>=3.1.6,<4.0)"
49+
"jinja2 (>=3.1.6,<4.0)",
50+
"qdrant-client (>=1.15.1,<2.0.0)"
5051
]
5152

5253
[project.urls]

0 commit comments

Comments
 (0)