Skip to content

Commit 5e2a592

Browse files
committed
testing
1 parent 9cec7fe commit 5e2a592

File tree

2 files changed

+224
-14
lines changed

2 files changed

+224
-14
lines changed

src/vdf_io/export_vdf/weaviate_export.py

Lines changed: 85 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,18 @@
22

33
from tqdm import tqdm
44
import weaviate
5+
import json
56

67
from vdf_io.export_vdf.vdb_export_cls import ExportVDB
8+
from vdf_io.meta_types import NamespaceMeta
79
from vdf_io.names import DBNames
810
from vdf_io.util import set_arg_from_input, set_arg_from_password
11+
from typing import Dict, List
912

1013
# Set these environment variables
1114
URL = os.getenv("YOUR_WCS_URL")
1215
APIKEY = os.getenv("YOUR_WCS_API_KEY")
16+
OPENAI_APIKEY = os.getenv("OPENAI_APIKEY")
1317

1418

1519
class ExportWeaviate(ExportVDB):
@@ -23,6 +27,15 @@ def make_parser(cls, subparsers):
2327

2428
# Register the Weaviate-specific CLI flags on the export sub-parser.
parser_weaviate.add_argument("--url", type=str, help="URL of Weaviate instance")
parser_weaviate.add_argument("--api_key", type=str, help="Weaviate API key")
parser_weaviate.add_argument("--openai_api_key", type=str, help="Openai API key")
# BUG FIX: argparse has no ``add_arguments`` method — the previous call
# raised AttributeError as soon as the sub-parser was built.
parser_weaviate.add_argument(
    "--batch_size", type=int, help="batch size for fetching",
    default=1000,
)
parser_weaviate.add_argument(
    "--connection-type", type=str, choices=["local", "cloud"], default="cloud",
    help="Type of connection to Weaviate (local or cloud)",
)
parser_weaviate.add_argument(
    "--classes", type=str, help="Classes to export (comma-separated)"
)
@@ -35,6 +48,12 @@ def export_vdb(cls, args):
3548
"Enter the URL of Weaviate instance: ",
3649
str,
3750
)
51+
set_arg_from_input(
52+
args,
53+
"connection_type",
54+
"Enter 'local' or 'cloud' for connection types: ",
55+
choices=['local', 'cloud'],
56+
)
3857
set_arg_from_password(
3958
args,
4059
"api_key",
@@ -55,14 +74,20 @@ def export_vdb(cls, args):
5574
weaviate_export.get_data()
5675
return weaviate_export
5776

58-
# Connect to a WCS or local instance
def __init__(self, args):
    """Create the export client.

    Connects to a local Weaviate instance when ``connection_type`` is
    ``"local"``; otherwise connects to Weaviate Cloud (WCS) with the
    configured URL and API key.  When an OpenAI key is supplied it is
    forwarded as the ``X-OpenAI-Api-key`` header.

    Uses ``args.get(...)`` for the optional keys: the previous direct
    subscripting raised ``KeyError`` whenever the class was constructed
    with an args dict that lacked ``connection_type`` or
    ``openai_api_key`` (argparse-populated dicts always contain them,
    so behavior on the CLI path is unchanged).
    """
    super().__init__(args)
    if self.args.get("connection_type") == "local":
        self.client = weaviate.connect_to_local()
    else:
        # Only send the OpenAI header when a key was actually provided.
        openai_key = self.args.get("openai_api_key")
        self.client = weaviate.connect_to_wcs(
            cluster_url=self.args["url"],
            auth_credentials=weaviate.auth.AuthApiKey(self.args["api_key"]),
            headers={'X-OpenAI-Api-key': openai_key} if openai_key else None,
            skip_init_checks=True,
        )
6691

6792
def get_index_names(self):
6893
if self.args.get("classes") is None:
@@ -76,14 +101,60 @@ def get_index_names(self):
76101
return [c for c in self.all_classes if c in input_classes]
77102

78103
def get_data(self):
    """Export every selected Weaviate class to Parquet plus VDF_META.json.

    For each index (class): count the objects, fetch them in batches,
    write vectors and scalar attributes via ``save_vectors_to_parquet``,
    and record a ``NamespaceMeta``.  Finally the VDF metadata is written
    to ``VDF_META.json`` in the export directory.

    Returns:
        True on completion.
    """
    # Get the index names to export
    index_names = self.get_index_names()
    index_metas: Dict[str, List[NamespaceMeta]] = {}

    # Iterate over index names and fetch data
    for index_name in index_names:
        collection = self.client.collections.get(index_name)
        response = collection.aggregate.over_all(total_count=True)
        total_vector_count = response.total_count

        # Create vectors directory for this index
        vectors_directory = self.create_vec_dir(index_name)

        # Export data in batches; ceil division so a partial final batch
        # is still fetched.  Fall back to 1000 if batch_size is unset.
        batch_size = self.args.get("batch_size") or 1000
        num_batches = (total_vector_count + batch_size - 1) // batch_size
        num_vectors_exported = 0

        for batch_idx in tqdm(range(num_batches), desc=f"Exporting {index_name}"):
            offset = batch_idx * batch_size
            objects = collection.objects.limit(batch_size).offset(offset).get()

            # Extract vectors and metadata
            vectors = {obj.id: obj.vector for obj in objects}
            # NOTE(review): dir()-based introspection is a stopgap.  The
            # old filter only excluded dunders, so bound methods and
            # single-underscore internals leaked into the metadata; skip
            # anything private or callable so only plain data attributes
            # are exported.  Confirm whether the client exposes a
            # ``.properties`` dict that should be used instead.
            metadata = {
                obj.id: {
                    attr: getattr(obj, attr)
                    for attr in dir(obj)
                    if not attr.startswith("_")
                    and not callable(getattr(obj, attr))
                }
                for obj in objects
            }

            # Save vectors and metadata to Parquet file
            num_vectors_exported += self.save_vectors_to_parquet(
                vectors, metadata, vectors_directory
            )

        # Create NamespaceMeta for this index
        namespace_metas = [
            self.get_namespace_meta(
                index_name,
                vectors_directory,
                total=total_vector_count,
                num_vectors_exported=num_vectors_exported,
                dim=300,  # TODO: read the real dimension from the class schema
                distance="Cosine",
            )
        ]
        index_metas[index_name] = namespace_metas

    # Write VDFMeta to JSON file.  BUG FIX: previously the path was only
    # appended to file_structure and the JSON printed — the file itself
    # was never written to disk.
    meta_path = os.path.join(self.vdf_directory, "VDF_META.json")
    self.file_structure.append(meta_path)
    internal_metadata = self.get_basic_vdf_meta(index_metas)
    meta_text = json.dumps(internal_metadata.model_dump(), indent=4)
    tqdm.write(meta_text)
    with open(meta_path, "w") as json_file:
        json_file.write(meta_text)

    print("Data export complete.")

    return True
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
import os
2+
import weaviate
3+
import json
4+
from tqdm import tqdm
5+
from vdf_io.import_vdf.vdf_import_cls import ImportVDB
6+
from vdf_io.names import DBNames
7+
from vdf_io.util import set_arg_from_input, set_arg_from_password
8+
9+
# Set these environment variables
10+
URL = os.getenv("YOUR_WCS_URL")
11+
APIKEY = os.getenv("YOUR_WCS_API_KEY")
12+
13+
14+
class ImportWeaviate(ImportVDB):
    """Import a VDF dataset into a Weaviate instance (WCS or local).

    NOTE(review): the bulk of ``upsert_data`` is still commented-out
    work in progress — no vectors are actually written yet.
    """

    DB_NAME_SLUG = DBNames.WEAVIATE

    @classmethod
    def make_parser(cls, subparsers):
        """Register the weaviate import sub-command and its CLI options."""
        parser_weaviate = subparsers.add_parser(
            cls.DB_NAME_SLUG, help="Import data into Weaviate"
        )

        parser_weaviate.add_argument("--url", type=str, help="URL of Weaviate instance")
        parser_weaviate.add_argument("--api_key", type=str, help="Weaviate API key")
        # BUG FIX: __init__ reads connection_type and openai_api_key from
        # args, but they were never registered here, so the CLI path
        # raised KeyError.  Mirrors the export parser's options.
        parser_weaviate.add_argument(
            "--openai_api_key", type=str, help="Openai API key"
        )
        parser_weaviate.add_argument(
            "--connection-type", type=str, choices=["local", "cloud"],
            default="cloud",
            help="Type of connection to Weaviate (local or cloud)",
        )
        parser_weaviate.add_argument(
            "--index_name", type=str, help="Name of the index in Weaviate"
        )

    @classmethod
    def import_vdb(cls, args):
        """Prompt for any missing connection settings, then run the import."""
        set_arg_from_input(
            args,
            "url",
            "Enter the URL of Weaviate instance: ",
            str,
        )
        set_arg_from_password(
            args,
            "api_key",
            "Enter the Weaviate API key: ",
            "WEAVIATE_API_KEY",
        )
        set_arg_from_input(
            args,
            "index_name",
            "Enter the name of the index in Weaviate: ",
            str,
        )
        # Use cls(...) so subclasses get the right type.
        weaviate_import = cls(args)
        weaviate_import.upsert_data()
        return weaviate_import

    def __init__(self, args):
        """Connect to a local instance or to Weaviate Cloud (WCS).

        Optional keys are read with ``args.get(...)`` so a partial args
        dict defaults to a cloud connection instead of raising KeyError.
        """
        super().__init__(args)
        if self.args.get("connection_type") == "local":
            self.client = weaviate.connect_to_local()
        else:
            # Only send the OpenAI header when a key was actually provided.
            openai_key = self.args.get("openai_api_key")
            self.client = weaviate.connect_to_wcs(
                cluster_url=self.args["url"],
                auth_credentials=weaviate.auth.AuthApiKey(self.args["api_key"]),
                headers={'X-OpenAI-Api-key': openai_key} if openai_key else None,
                skip_init_checks=True,
            )

    def upsert_data(self):
        """Import the vectors described by the VDF metadata into Weaviate.

        Currently only resolves index names, collections, and parquet
        file locations; the batch-upsert loop below is commented-out WIP.
        """
        max_hit = False
        total_imported_count = 0

        # Iterate over the indexes and import the data
        for index_name, index_meta in tqdm(
            self.vdf_meta["indexes"].items(), desc="Importing indexes"
        ):
            tqdm.write(f"Importing data for index '{index_name}'")
            for namespace_meta in index_meta:
                self.set_dims(namespace_meta, index_name)

                # Create or get the index
                index_name = self.create_new_name(
                    index_name, self.client.collections.list_all().keys()
                )
                index = self.client.collections.get(index_name)

                # Load data from the Parquet files
                data_path = namespace_meta["data_path"]
                final_data_path = self.get_final_data_path(data_path)
                parquet_files = self.get_parquet_files(final_data_path)

                vectors = {}
                metadata = {}

                # for file in tqdm(parquet_files, desc="Loading data from parquet files"):
                #     file_path = os.path.join(final_data_path, file)
                #     df = self.read_parquet_progress(file_path)

                #     if len(vectors) > (self.args.get("max_num_rows") or INT_MAX):
                #         max_hit = True
                #         break

                #     self.update_vectors(vectors, vector_column_name, df)
                #     self.update_metadata(metadata, vector_column_names, df)
                # if max_hit:
                #     break

                # tqdm.write(f"Loaded {len(vectors)} vectors from {len(parquet_files)} parquet files")

                # # Upsert the vectors and metadata to the Weaviate index in batches
                # BATCH_SIZE = self.args.get("batch_size", 1000) or 1000
                # current_batch_size = BATCH_SIZE
                # start_idx = 0

                # while start_idx < len(vectors):
                #     end_idx = min(start_idx + current_batch_size, len(vectors))

                #     batch_vectors = [
                #         (
                #             str(id),
                #             vector,
                #             {
                #                 k: v
                #                 for k, v in metadata.get(id, {}).items()
                #                 if v is not None
                #             } if len(metadata.get(id, {}).keys()) > 0 else None
                #         )
                #         for id, vector in list(vectors.items())[start_idx:end_idx]
                #     ]

                #     try:
                #         resp = index.batch.create(batch_vectors)
                #         total_imported_count += len(batch_vectors)
                #         start_idx += len(batch_vectors)
                #     except Exception as e:
                #         tqdm.write(f"Error upserting vectors for index '{index_name}', {e}")
                #         if current_batch_size < BATCH_SIZE / 100:
                #             tqdm.write("Batch size is not the issue. Aborting import")
                #             raise e
                #         current_batch_size = int(2 * current_batch_size / 3)
                #         tqdm.write(f"Reducing batch size to {current_batch_size}")
                #         continue

        # tqdm.write(f"Data import completed successfully. Imported {total_imported_count} vectors")
        # self.args["imported_count"] = total_imported_count

0 commit comments

Comments
 (0)