from tqdm import tqdm
import weaviate
+import json

from vdf_io.export_vdf.vdb_export_cls import ExportVDB
+from vdf_io.meta_types import NamespaceMeta
from vdf_io.names import DBNames
from vdf_io.util import set_arg_from_input, set_arg_from_password
+from typing import Dict, List

# Set these environment variables
URL = os.getenv("YOUR_WCS_URL")
APIKEY = os.getenv("YOUR_WCS_API_KEY")
+OPENAI_APIKEY = os.getenv("OPENAI_APIKEY")


class ExportWeaviate(ExportVDB):
@@ -23,6 +27,15 @@ def make_parser(cls, subparsers):
        parser_weaviate.add_argument("--url", type=str, help="URL of Weaviate instance")
        parser_weaviate.add_argument("--api_key", type=str, help="Weaviate API key")
+        parser_weaviate.add_argument("--openai_api_key", type=str, help="OpenAI API key")
+        parser_weaviate.add_argument(
+            "--batch_size", type=int, help="Batch size for fetching", default=1000
+        )
+        parser_weaviate.add_argument(
+            "--connection-type", type=str, choices=["local", "cloud"], default="cloud",
+            help="Type of connection to Weaviate (local or cloud)"
+        )
        parser_weaviate.add_argument(
            "--classes", type=str, help="Classes to export (comma-separated)"
        )
@@ -35,6 +48,12 @@ def export_vdb(cls, args):
            "Enter the URL of Weaviate instance: ",
            str,
        )
+        set_arg_from_input(
+            args,
+            "connection_type",
+            "Enter the connection type ('local' or 'cloud'): ",
+            choices=["local", "cloud"],
+        )
        set_arg_from_password(
            args,
            "api_key",
@@ -55,14 +74,20 @@ def export_vdb(cls, args):
        weaviate_export.get_data()
        return weaviate_export

-    # Connect to a WCS instance
+    # Connect to a WCS or local instance
    def __init__(self, args):
        super().__init__(args)
-        self.client = weaviate.connect_to_wcs(
-            cluster_url=self.args["url"],
-            auth_credentials=weaviate.auth.AuthApiKey(self.args["api_key"]),
-            skip_init_checks=True,
-        )
+        if self.args["connection_type"] == "local":
+            self.client = weaviate.connect_to_local()
+        else:
+            self.client = weaviate.connect_to_wcs(
+                cluster_url=self.args["url"],
+                auth_credentials=weaviate.auth.AuthApiKey(self.args["api_key"]),
+                headers={"X-OpenAI-Api-key": self.args["openai_api_key"]}
+                if self.args["openai_api_key"]
+                else None,
+                skip_init_checks=True,
+            )

    def get_index_names(self):
        if self.args.get("classes") is None:
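
The connection branch added in __init__ above can also be exercised on its own. Below is a minimal standalone sketch, assuming the weaviate-client v4 API used in this PR; the CONNECTION_TYPE environment variable is hypothetical and used only for illustration.

# Minimal smoke test of the local-vs-cloud connection logic (sketch, not part of the PR).
import os
import weaviate

connection_type = os.getenv("CONNECTION_TYPE", "cloud")  # hypothetical env var for illustration
if connection_type == "local":
    client = weaviate.connect_to_local()
else:
    client = weaviate.connect_to_wcs(
        cluster_url=os.getenv("YOUR_WCS_URL"),
        auth_credentials=weaviate.auth.AuthApiKey(os.getenv("YOUR_WCS_API_KEY")),
        headers={"X-OpenAI-Api-key": os.getenv("OPENAI_APIKEY")}
        if os.getenv("OPENAI_APIKEY")
        else None,
        skip_init_checks=True,
    )
print(client.is_ready())  # readiness check exposed by the v4 client
client.close()
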
@@ -76,14 +101,60 @@ def get_index_names(self):
        return [c for c in self.all_classes if c in input_classes]

    def get_data(self):
-        # Get all objects of a class
+        # Get the index names to export
        index_names = self.get_index_names()
-        for class_name in index_names:
-            collection = self.client.collections.get(class_name)
+        index_metas: Dict[str, List[NamespaceMeta]] = {}
+
+        # Iterate over index names and fetch data
+        for index_name in index_names:
+            collection = self.client.collections.get(index_name)
            response = collection.aggregate.over_all(total_count=True)
-            print(f"{response.total_count=}")
+            total_vector_count = response.total_count
+
+            # Create vectors directory for this index
+            vectors_directory = self.create_vec_dir(index_name)
+
+            # Export data in batches
+            batch_size = self.args["batch_size"]
+            num_batches = (total_vector_count + batch_size - 1) // batch_size
+            num_vectors_exported = 0
+
+            for batch_idx in tqdm(range(num_batches), desc=f"Exporting {index_name}"):
+                offset = batch_idx * batch_size
+                objects = collection.objects.limit(batch_size).offset(offset).get()
+
+                # Extract vectors and metadata
+                vectors = {obj.id: obj.vector for obj in objects}
+                metadata = {}
+                # Need a better way
+                for obj in objects:
+                    metadata[obj.id] = {
+                        attr: getattr(obj, attr) for attr in dir(obj) if not attr.startswith("__")
+                    }
+
+                # Save vectors and metadata to Parquet file
+                num_vectors_exported += self.save_vectors_to_parquet(
+                    vectors, metadata, vectors_directory
+                )
+
+            # Create NamespaceMeta for this index
+            namespace_metas = [
+                self.get_namespace_meta(
+                    index_name,
+                    vectors_directory,
+                    total=total_vector_count,
+                    num_vectors_exported=num_vectors_exported,
+                    dim=300,  # Not sure of the dimensions
+                    distance="Cosine",
+                )
+            ]
+            index_metas[index_name] = namespace_metas
+
+        # Write VDFMeta to JSON file
+        self.file_structure.append(os.path.join(self.vdf_directory, "VDF_META.json"))
+        internal_metadata = self.get_basic_vdf_meta(index_metas)
+        meta_text = json.dumps(internal_metadata.model_dump(), indent=4)
+        tqdm.write(meta_text)
+
+        print("Data export complete.")

-        # objects = self.client.query.get(
-        #     wvq.Objects(wvq.Class(class_name)).with_limit(1000)
-        # )
-        # print(objects)
+        return True
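
The batched fetch in get_data (collection.objects.limit(batch_size).offset(offset).get()) may need adjusting against the v4 collection API. Below is a minimal sketch of one way to page through a collection with vectors, assuming weaviate-client v4's collection.query.fetch_objects (limit, offset, and include_vector are real parameters); client, index_name, and batch_size are taken from the surrounding code. collection.iterator(include_vector=True) is a cursor-based alternative that avoids offset limits.

# Sketch only (assumes weaviate-client v4): page through a collection and collect
# ids, vectors, and properties in roughly the shape the exporter expects.
collection = client.collections.get(index_name)
offset = 0
while True:
    response = collection.query.fetch_objects(
        limit=batch_size, offset=offset, include_vector=True
    )
    if not response.objects:
        break
    # obj.vector may be a dict of named vectors in newer client versions
    vectors = {str(obj.uuid): obj.vector for obj in response.objects}
    metadata = {str(obj.uuid): obj.properties for obj in response.objects}
    # ...pass vectors/metadata to save_vectors_to_parquet here...
    offset += batch_size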