1+ import multiprocessing as mp
2+ from typing import List , Tuple
3+
4+ from cassandra .cluster import Cluster , ExecutionProfile , EXEC_PROFILE_DEFAULT
5+ from cassandra .policies import DCAwareRoundRobinPolicy , TokenAwarePolicy , ExponentialReconnectionPolicy
6+ from cassandra import ConsistencyLevel , ProtocolVersion
7+
8+ from dataset_reader .base_reader import Query
9+ from engine .base_client .distances import Distance
10+ from engine .base_client .search import BaseSearcher
11+ from engine .clients .cassandra .config import CASSANDRA_KEYSPACE , CASSANDRA_TABLE
12+ from engine .clients .cassandra .parser import CassandraConditionParser
13+
14+
class CassandraSearcher(BaseSearcher):
    """Searcher that runs ANN vector queries against Cassandra.

    All state (cluster, session, prepared statements) is stored on the class
    and every method is a classmethod, so one connection is shared by all
    calls made within a process.
    """

    search_params = {}
    session = None
    cluster = None
    parser = CassandraConditionParser()

    # Result limit used when neither the caller nor ``search_params`` gives one.
    _DEFAULT_LIMIT = 10
    # Upper bound on the number of prepared statements kept on the client.
    _STATEMENT_CACHE_SIZE = 256
    # Maps CQL text -> prepared statement; rebuilt by update_prepared_statements.
    _statement_cache = {}
    # "SELECT ... FROM ..." prefix for the active distance metric.
    _select_clause = None

    @classmethod
    def init_client(cls, host, distance, connection_params: dict, search_params: dict):
        """Connect to Cassandra and prepare the search statements.

        Args:
            host: Contact point handed to the driver.
            distance: ``Distance`` member selecting the similarity function.
            connection_params: Extra keyword arguments for ``Cluster``. Keys
                given here override the defaults chosen below instead of
                raising a duplicate-keyword ``TypeError``.
            search_params: Search settings; ``"top"`` is used as the default
                result limit.

        Raises:
            ValueError: If ``distance`` is not a supported metric.
        """
        # Set up execution profiles for consistency and performance
        profile = ExecutionProfile(
            load_balancing_policy=TokenAwarePolicy(DCAwareRoundRobinPolicy()),
            consistency_level=ConsistencyLevel.LOCAL_ONE,  # Use LOCAL_ONE for faster reads
            request_timeout=60,
        )

        # Defaults first, then caller-supplied settings so configuration wins.
        cluster_kwargs = {
            "contact_points": [host],
            "execution_profiles": {EXEC_PROFILE_DEFAULT: profile},
            "reconnection_policy": ExponentialReconnectionPolicy(
                base_delay=1, max_delay=60
            ),
            "protocol_version": ProtocolVersion.V4,
        }
        cluster_kwargs.update(connection_params or {})

        # Initialize Cassandra cluster connection
        cls.cluster = Cluster(**cluster_kwargs)
        cls.session = cls.cluster.connect(CASSANDRA_KEYSPACE)
        cls.search_params = search_params if search_params is not None else {}

        # Update prepared statements with current search parameters
        cls.update_prepared_statements(distance)

    @classmethod
    def get_mp_start_method(cls):
        """Return ``"fork"`` when the platform offers it, otherwise ``"spawn"``."""
        return "fork" if "fork" in mp.get_all_start_methods() else "spawn"

    @classmethod
    def update_prepared_statements(cls, distance):
        """Create prepared statements for vector searches.

        Sets ``ann_search_stmt`` (unfiltered search with the default limit)
        and ``filtered_search_query_template`` (CQL text with a
        ``{conditions}`` placeholder), and resets the statement cache.

        Raises:
            ValueError: If ``distance`` is not COSINE, L2 or DOT.
        """
        if distance == Distance.COSINE:
            similarity_func = "similarity_cosine"
        elif distance == Distance.L2:
            similarity_func = "similarity_euclidean"
        elif distance == Distance.DOT:
            similarity_func = "similarity_dot_product"
        else:
            raise ValueError(f"Unsupported distance metric: {distance}")

        cls._select_clause = (
            f"SELECT id, {similarity_func}(embedding, ?) as distance\n"
            f"FROM {CASSANDRA_TABLE}"
        )
        # Statements prepared for a previous metric must not be reused.
        cls._statement_cache = {}

        default_limit = cls._default_limit()
        cls.ann_search_stmt = cls._get_statement("", default_limit)
        # Kept as plain text: the WHERE clause is only known per query.
        cls.filtered_search_query_template = cls._build_query(
            "{conditions}", default_limit
        )

    @classmethod
    def _default_limit(cls) -> int:
        """Return ``search_params["top"]``, or the built-in default if unset."""
        return cls.search_params.get("top") or cls._DEFAULT_LIMIT

    @classmethod
    def _build_query(cls, conditions: str, limit) -> str:
        """Assemble the CQL text for an ANN search with an optional WHERE clause."""
        where_clause = f"WHERE {conditions}\n" if conditions else ""
        return (
            f"{cls._select_clause}\n"
            f"{where_clause}"
            f"ORDER BY embedding ANN OF ?\n"
            f"LIMIT {limit}"
        )

    @classmethod
    def _get_statement(cls, conditions: str, limit):
        """Return a prepared statement for the query, preparing it only once."""
        cql = cls._build_query(conditions, limit)
        statement = cls._statement_cache.get(cql)
        if statement is None:
            if len(cls._statement_cache) >= cls._STATEMENT_CACHE_SIZE:
                # Dicts keep insertion order, so this evicts the oldest entry.
                cls._statement_cache.pop(next(iter(cls._statement_cache)))
            statement = cls.session.prepare(cql)
            cls._statement_cache[cql] = statement
        return statement

    @classmethod
    def search_one(cls, query: Query, top: int) -> List[Tuple[int, float]]:
        """Execute a vector similarity search with optional filters.

        Args:
            query: Query whose ``vector`` is searched for and whose
                ``meta_conditions`` are turned into a WHERE clause.
            top: Maximum number of rows to return. A falsy value falls back
                to ``search_params["top"]`` (or the built-in default).

        Returns:
            A list of ``(id, distance)`` tuples, one per returned row.

        Raises:
            Exception: Any error from parsing-free execution is printed and
                re-raised unchanged.
        """
        # Convert query vector to a format Cassandra can use
        vector = query.vector
        if hasattr(vector, "tolist"):
            vector = vector.tolist()

        # Generate filter conditions if metadata conditions exist
        filter_conditions = cls.parser.parse(query.meta_conditions)
        limit = int(top) if top else cls._default_limit()

        try:
            statement = cls._get_statement(filter_conditions or "", limit)
            # The vector is bound twice: once for the similarity function and
            # once for the ANN ordering clause.
            rows = cls.session.execute(statement, (vector, vector))
            return [(row.id, row.distance) for row in rows]
        except Exception as ex:
            print(f"Error during Cassandra vector search: {ex}")
            raise

    @classmethod
    def delete_client(cls):
        """Close the Cassandra connection and drop all cached state.

        Safe to call more than once; the cluster is shut down even if the
        session shutdown raises.
        """
        session, cluster = cls.session, cls.cluster
        cls.session = None
        cls.cluster = None
        cls._statement_cache = {}
        try:
            if session:
                session.shutdown()
        finally:
            if cluster:
                cluster.shutdown()
0 commit comments