Commit fa5f40b

1 parent 5e2a592 commit fa5f40b

File tree

2 files changed: +102 −84 lines


src/vdf_io/export_vdf/weaviate_export.py

Lines changed: 42 additions & 25 deletions
@@ -1,13 +1,14 @@
 import os
-
-from tqdm import tqdm
 import weaviate
 import json
 
+from tqdm import tqdm
+from weaviate.classes.query import MetadataQuery
 from vdf_io.export_vdf.vdb_export_cls import ExportVDB
 from vdf_io.meta_types import NamespaceMeta
 from vdf_io.names import DBNames
 from vdf_io.util import set_arg_from_input, set_arg_from_password
+from vdf_io.constants import DEFAULT_BATCH_SIZE
 from typing import Dict, List
 
 # Set these environment variables
@@ -28,9 +29,13 @@ def make_parser(cls, subparsers):
         parser_weaviate.add_argument("--url", type=str, help="URL of Weaviate instance")
         parser_weaviate.add_argument("--api_key", type=str, help="Weaviate API key")
         parser_weaviate.add_argument("--openai_api_key", type=str, help="Openai API key")
-        parser_weaviate.add_arguments(
+        parser_weaviate.add_argument(
             "--batch_size", type=int, help="batch size for fetching",
-            default=1000
+            default=DEFAULT_BATCH_SIZE
+        )
+        parser_weaviate.add_argument(
+            "--offset", type=int, help="offset for fetching",
+            default=None
         )
         parser_weaviate.add_argument(
             "--connection-type", type=str, choices=["local", "cloud"], default="cloud",
@@ -100,39 +105,50 @@ def get_index_names(self):
         )
         return [c for c in self.all_classes if c in input_classes]
 
+    def metadata_to_dict(self, metadata):
+        meta_data = {}
+        meta_data["creation_time"] = metadata.creation_time
+        meta_data["distance"] = metadata.distance
+        meta_data["certainty"] = metadata.certainty
+        meta_data["explain_score"] = metadata.explain_score
+        meta_data["is_consistent"] = metadata.is_consistent
+        meta_data["last_update_time"] = metadata.last_update_time
+        meta_data["rerank_score"] = metadata.rerank_score
+        meta_data["score"] = metadata.score
+
+        return meta_data
+
     def get_data(self):
         # Get the index names to export
         index_names = self.get_index_names()
         index_metas: Dict[str, List[NamespaceMeta]] = {}
 
+        # Export data in batches
+        batch_size = self.args["batch_size"]
+        offset = self.args["offset"]
+
         # Iterate over index names and fetch data
         for index_name in index_names:
             collection = self.client.collections.get(index_name)
-            response = collection.aggregate.over_all(total_count=True)
-            total_vector_count = response.total_count
+            response = collection.query.fetch_objects(
+                limit=batch_size,
+                offset=offset,
+                include_vector=True,
+                return_metadata=MetadataQuery.full()
+            )
+            res = collection.aggregate.over_all(total_count=True)
+            total_vector_count = res.total_count
 
             # Create vectors directory for this index
             vectors_directory = self.create_vec_dir(index_name)
 
-            # Export data in batches
-            batch_size = self.args["batch_size"]
-            num_batches = (total_vector_count + batch_size - 1) // batch_size
-            num_vectors_exported = 0
-
-            for batch_idx in tqdm(range(num_batches), desc=f"Exporting {index_name}"):
-                offset = batch_idx * batch_size
-                objects = collection.objects.limit(batch_size).offset(offset).get()
-
-                # Extract vectors and metadata
-                vectors = {obj.id: obj.vector for obj in objects}
-                metadata = {}
-                # Need a better way
-                for obj in objects:
-                    metadata[obj.id] = {attr: getattr(obj, attr) for attr in dir(obj) if not attr.startswith("__")}
-
+            for obj in response.objects:
+                vectors = obj.vector
+                metadata = obj.metadata
+                metadata = self.metadata_to_dict(metadata=metadata)
 
             # Save vectors and metadata to Parquet file
-            num_vectors_exported += self.save_vectors_to_parquet(
+            num_vectors_exported = self.save_vectors_to_parquet(
                 vectors, metadata, vectors_directory
             )
 
@@ -143,7 +159,7 @@ def get_data(self):
                     vectors_directory,
                     total=total_vector_count,
                     num_vectors_exported=num_vectors_exported,
-                    dim=300, # Not sure of the dimensions
+                    dim=-1,
                     distance="Cosine",
                 )
             ]
@@ -154,7 +170,8 @@ def get_data(self):
         internal_metadata = self.get_basic_vdf_meta(index_metas)
         meta_text = json.dumps(internal_metadata.model_dump(), indent=4)
         tqdm.write(meta_text)
-
+        with open(os.path.join(self.vdf_directory, "VDF_META.json"), "w") as json_file:
+            json_file.write(meta_text)
         print("Data export complete.")
 
         return True
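Note that fetch_objects returns at most limit objects per call, while total_vector_count comes from the separate aggregate query. A hedged sketch of paging through an entire collection by advancing the offset; the names are illustrative and not part of this commit:

def fetch_all_objects(collection, batch_size=100):
    # Page through the collection until an empty page signals the end.
    offset = 0
    while True:
        page = collection.query.fetch_objects(
            limit=batch_size,
            offset=offset,
            include_vector=True,
        )
        if not page.objects:
            break
        yield from page.objects
        offset += len(page.objects)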
src/vdf_io/import_vdf/weaviate_import.py

Lines changed: 60 additions & 59 deletions
@@ -1,10 +1,10 @@
 import os
 import weaviate
-import json
 from tqdm import tqdm
 from vdf_io.import_vdf.vdf_import_cls import ImportVDB
 from vdf_io.names import DBNames
 from vdf_io.util import set_arg_from_input, set_arg_from_password
+from vdf_io.constants import INT_MAX, DEFAULT_BATCH_SIZE
 
 # Set these environment variables
 URL = os.getenv("YOUR_WCS_URL")
@@ -25,6 +25,14 @@ def make_parser(cls, subparsers):
         parser_weaviate.add_argument(
             "--index_name", type=str, help="Name of the index in Weaviate"
         )
+        parser_weaviate.add_argument(
+            "--connection-type", type=str, choices=["local", "cloud"], default="cloud",
+            help="Type of connection to Weaviate (local or cloud)"
+        )
+        parser_weaviate.add_argument(
+            "--batch_size", type=int, help="batch size for fetching",
+            default=DEFAULT_BATCH_SIZE
+        )
 
     @classmethod
     def import_vdb(cls, args):
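The new --connection-type flag mirrors the export parser and distinguishes local from cloud deployments. The connection code itself is not part of this diff; a sketch of how the two modes typically look in weaviate-client v4, with placeholder URL and key:

import weaviate
from weaviate.classes.init import Auth

def connect(connection_type: str, url: str = "", api_key: str = ""):
    # "local" targets a default localhost instance; "cloud" needs the
    # cluster URL and an API key.
    if connection_type == "local":
        return weaviate.connect_to_local()
    return weaviate.connect_to_weaviate_cloud(
        cluster_url=url,
        auth_credentials=Auth.api_key(api_key),
    )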
@@ -34,18 +42,24 @@ def import_vdb(cls, args):
             "Enter the URL of Weaviate instance: ",
             str,
         )
-        set_arg_from_password(
-            args,
-            "api_key",
-            "Enter the Weaviate API key: ",
-            "WEAVIATE_API_KEY",
-        )
         set_arg_from_input(
             args,
             "index_name",
             "Enter the name of the index in Weaviate: ",
             str,
         )
+        set_arg_from_input(
+            args,
+            "connection_type",
+            "Enter 'local' or 'cloud' for connection types: ",
+            choices=['local', 'cloud'],
+        )
+        set_arg_from_password(
+            args,
+            "api_key",
+            "Enter the Weaviate API key: ",
+            "WEAVIATE_API_KEY",
+        )
         weaviate_import = ImportWeaviate(args)
         weaviate_import.upsert_data()
         return weaviate_import
@@ -76,7 +90,6 @@ def upsert_data(self):
 
         # Create or get the index
         index_name = self.create_new_name(index_name, self.client.collections.list_all().keys())
-        index = self.client.collections.get(index_name)
 
         # Load data from the Parquet files
         data_path = namespace_meta["data_path"]
@@ -85,55 +98,43 @@ def upsert_data(self):
 
         vectors = {}
         metadata = {}
+        vector_column_names, vector_column_name = self.get_vector_column_name(
+            index_name, namespace_meta
+        )
 
-        # for file in tqdm(parquet_files, desc="Loading data from parquet files"):
-        #     file_path = os.path.join(final_data_path, file)
-        #     df = self.read_parquet_progress(file_path)
-
-        #     if len(vectors) > (self.args.get("max_num_rows") or INT_MAX):
-        #         max_hit = True
-        #         break
-
-        #     self.update_vectors(vectors, vector_column_name, df)
-        #     self.update_metadata(metadata, vector_column_names, df)
-        #     if max_hit:
-        #         break
-
-        # tqdm.write(f"Loaded {len(vectors)} vectors from {len(parquet_files)} parquet files")
-
-        # # Upsert the vectors and metadata to the Weaviate index in batches
-        # BATCH_SIZE = self.args.get("batch_size", 1000) or 1000
-        # current_batch_size = BATCH_SIZE
-        # start_idx = 0
-
-        # while start_idx < len(vectors):
-        #     end_idx = min(start_idx + current_batch_size, len(vectors))
-
-        #     batch_vectors = [
-        #         (
-        #             str(id),
-        #             vector,
-        #             {
-        #                 k: v
-        #                 for k, v in metadata.get(id, {}).items()
-        #                 if v is not None
-        #             } if len(metadata.get(id, {}).keys()) > 0 else None
-        #         )
-        #         for id, vector in list(vectors.items())[start_idx:end_idx]
-        #     ]
-
-        #     try:
-        #         resp = index.batch.create(batch_vectors)
-        #         total_imported_count += len(batch_vectors)
-        #         start_idx += len(batch_vectors)
-        #     except Exception as e:
-        #         tqdm.write(f"Error upserting vectors for index '{index_name}', {e}")
-        #         if current_batch_size < BATCH_SIZE / 100:
-        #             tqdm.write("Batch size is not the issue. Aborting import")
-        #             raise e
-        #         current_batch_size = int(2 * current_batch_size / 3)
-        #         tqdm.write(f"Reducing batch size to {current_batch_size}")
-        #         continue
-
-        # tqdm.write(f"Data import completed successfully. Imported {total_imported_count} vectors")
-        # self.args["imported_count"] = total_imported_count
+        for file in tqdm(parquet_files, desc="Loading data from parquet files"):
+            file_path = os.path.join(final_data_path, file)
+            df = self.read_parquet_progress(file_path)
+
+            if len(vectors) > (self.args.get("max_num_rows") or INT_MAX):
+                max_hit = True
+                break
+            if len(vectors) + len(df) > (
+                self.args.get("max_num_rows") or INT_MAX
+            ):
+                df = df.head(
+                    (self.args.get("max_num_rows") or INT_MAX) - len(vectors)
+                )
+                max_hit = True
+            self.update_vectors(vectors, vector_column_name, df)
+            self.update_metadata(metadata, vector_column_names, df)
+            if max_hit:
+                break
+
+        tqdm.write(f"Loaded {len(vectors)} vectors from {len(parquet_files)} parquet files")
+
+        # Upsert the vectors and metadata to the Weaviate index in batches
+        BATCH_SIZE = self.args.get("batch_size")
+
+        with self.client.batch.fixed_size(batch_size=BATCH_SIZE) as batch:
+            for _, vector in vectors.items():
+                batch.add_object(
+                    vector=vector,
+                    collection=index_name
+                    # TODO: Find way to add metadata
+                )
+                total_imported_count += 1
+
+
+        tqdm.write(f"Data import completed successfully. Imported {total_imported_count} vectors")
+        self.args["imported_count"] = total_imported_count
