1
1
import os
2
2
import weaviate
3
- import json
4
3
from tqdm import tqdm
5
4
from vdf_io .import_vdf .vdf_import_cls import ImportVDB
6
5
from vdf_io .names import DBNames
7
6
from vdf_io .util import set_arg_from_input , set_arg_from_password
7
+ from vdf_io .constants import INT_MAX , DEFAULT_BATCH_SIZE
8
8
9
9
# Set these environment variables
10
10
URL = os .getenv ("YOUR_WCS_URL" )
@@ -25,6 +25,14 @@ def make_parser(cls, subparsers):
25
25
parser_weaviate .add_argument (
26
26
"--index_name" , type = str , help = "Name of the index in Weaviate"
27
27
)
28
+ parser_weaviate .add_argument (
29
+ "--connection-type" , type = str , choices = ["local" , "cloud" ], default = "cloud" ,
30
+ help = "Type of connection to Weaviate (local or cloud)"
31
+ )
32
+ parser_weaviate .add_argument (
33
+ "--batch_size" , type = int , help = "batch size for fetching" ,
34
+ default = DEFAULT_BATCH_SIZE
35
+ )
28
36
29
37
@classmethod
30
38
def import_vdb (cls , args ):
@@ -34,18 +42,24 @@ def import_vdb(cls, args):
34
42
"Enter the URL of Weaviate instance: " ,
35
43
str ,
36
44
)
37
- set_arg_from_password (
38
- args ,
39
- "api_key" ,
40
- "Enter the Weaviate API key: " ,
41
- "WEAVIATE_API_KEY" ,
42
- )
43
45
set_arg_from_input (
44
46
args ,
45
47
"index_name" ,
46
48
"Enter the name of the index in Weaviate: " ,
47
49
str ,
48
50
)
51
+ set_arg_from_input (
52
+ args ,
53
+ "connection_type" ,
54
+ "Enter 'local' or 'cloud' for connection types: " ,
55
+ choices = ['local' , 'cloud' ],
56
+ )
57
+ set_arg_from_password (
58
+ args ,
59
+ "api_key" ,
60
+ "Enter the Weaviate API key: " ,
61
+ "WEAVIATE_API_KEY" ,
62
+ )
49
63
weaviate_import = ImportWeaviate (args )
50
64
weaviate_import .upsert_data ()
51
65
return weaviate_import
@@ -76,7 +90,6 @@ def upsert_data(self):
76
90
77
91
# Create or get the index
78
92
index_name = self .create_new_name (index_name , self .client .collections .list_all ().keys ())
79
- index = self .client .collections .get (index_name )
80
93
81
94
# Load data from the Parquet files
82
95
data_path = namespace_meta ["data_path" ]
@@ -85,55 +98,43 @@ def upsert_data(self):
85
98
86
99
vectors = {}
87
100
metadata = {}
101
+ vector_column_names , vector_column_name = self .get_vector_column_name (
102
+ index_name , namespace_meta
103
+ )
88
104
89
- # for file in tqdm(parquet_files, desc="Loading data from parquet files"):
90
- # file_path = os.path.join(final_data_path, file)
91
- # df = self.read_parquet_progress(file_path)
92
-
93
- # if len(vectors) > (self.args.get("max_num_rows") or INT_MAX):
94
- # max_hit = True
95
- # break
96
-
97
- # self.update_vectors(vectors, vector_column_name, df)
98
- # self.update_metadata(metadata, vector_column_names, df)
99
- # if max_hit:
100
- # break
101
-
102
- # tqdm.write(f"Loaded {len(vectors)} vectors from {len(parquet_files)} parquet files")
103
-
104
- # # Upsert the vectors and metadata to the Weaviate index in batches
105
- # BATCH_SIZE = self.args.get("batch_size", 1000) or 1000
106
- # current_batch_size = BATCH_SIZE
107
- # start_idx = 0
108
-
109
- # while start_idx < len(vectors):
110
- # end_idx = min(start_idx + current_batch_size, len(vectors))
111
-
112
- # batch_vectors = [
113
- # (
114
- # str(id),
115
- # vector,
116
- # {
117
- # k: v
118
- # for k, v in metadata.get(id, {}).items()
119
- # if v is not None
120
- # } if len(metadata.get(id, {}).keys()) > 0 else None
121
- # )
122
- # for id, vector in list(vectors.items())[start_idx:end_idx]
123
- # ]
124
-
125
- # try:
126
- # resp = index.batch.create(batch_vectors)
127
- # total_imported_count += len(batch_vectors)
128
- # start_idx += len(batch_vectors)
129
- # except Exception as e:
130
- # tqdm.write(f"Error upserting vectors for index '{index_name}', {e}")
131
- # if current_batch_size < BATCH_SIZE / 100:
132
- # tqdm.write("Batch size is not the issue. Aborting import")
133
- # raise e
134
- # current_batch_size = int(2 * current_batch_size / 3)
135
- # tqdm.write(f"Reducing batch size to {current_batch_size}")
136
- # continue
137
-
138
- # tqdm.write(f"Data import completed successfully. Imported {total_imported_count} vectors")
139
- # self.args["imported_count"] = total_imported_count
105
+ for file in tqdm (parquet_files , desc = "Loading data from parquet files" ):
106
+ file_path = os .path .join (final_data_path , file )
107
+ df = self .read_parquet_progress (file_path )
108
+
109
+ if len (vectors ) > (self .args .get ("max_num_rows" ) or INT_MAX ):
110
+ max_hit = True
111
+ break
112
+ if len (vectors ) + len (df ) > (
113
+ self .args .get ("max_num_rows" ) or INT_MAX
114
+ ):
115
+ df = df .head (
116
+ (self .args .get ("max_num_rows" ) or INT_MAX ) - len (vectors )
117
+ )
118
+ max_hit = True
119
+ self .update_vectors (vectors , vector_column_name , df )
120
+ self .update_metadata (metadata , vector_column_names , df )
121
+ if max_hit :
122
+ break
123
+
124
+ tqdm .write (f"Loaded { len (vectors )} vectors from { len (parquet_files )} parquet files" )
125
+
126
+ # Upsert the vectors and metadata to the Weaviate index in batches
127
+ BATCH_SIZE = self .args .get ("batch_size" )
128
+
129
+ with self .client .batch .fixed_size (batch_size = BATCH_SIZE ) as batch :
130
+ for _ , vector in vectors .items ():
131
+ batch .add_object (
132
+ vector = vector ,
133
+ collection = index_name
134
+ # TODO: find a way to add the metadata collected above to each object
135
+ )
136
+ total_imported_count += 1
137
+
138
+
139
+ tqdm .write (f"Data import completed successfully. Imported { total_imported_count } vectors" )
140
+ self .args ["imported_count" ] = total_imported_count
0 commit comments