Skip to content

Commit 427caea

Browse files
committed
latest state and fixes from review
1 parent 83b01c0 commit 427caea

File tree

18 files changed

+1101
-375
lines changed

18 files changed

+1101
-375
lines changed

research/__init__.py

Whitespace-only changes.

research/multiprocesssing_communication_perf/README.md

Whitespace-only changes.

research/multiprocesssing_communication_perf/__init__.py

Whitespace-only changes.

research/multiprocesssing_communication_perf/requirements.txt

Whitespace-only changes.
Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
from __future__ import annotations
2+
3+
import csv
4+
import io
5+
import pickle
6+
import random
7+
import sys
8+
import time
9+
from typing import Any
10+
11+
import click
12+
import numpy as np
13+
from pydantic import BaseModel
14+
15+
from guidellm.utils import EncodingTypesAlias, MessageEncoding, SerializationTypesAlias
16+
17+
from .utils import create_all_test_objects
18+
19+
20+
def calculate_size(obj: Any) -> int:
    """Approximate the in-memory size of ``obj`` in bytes.

    Pydantic models are measured via ``sys.getsizeof(model.__dict__)``
    because ``sys.getsizeof`` on the instance itself would not account for
    the field payload. Containers that hold at least one model are summed
    element-wise. NOTE: ``sys.getsizeof`` is shallow, so nested payloads
    are only approximated.

    :param obj: Any object (pydantic model, container of models, or plain
        value) whose size should be estimated.
    :return: Estimated size in bytes.
    """
    if isinstance(obj, BaseModel):
        return sys.getsizeof(obj.__dict__)

    if isinstance(obj, (tuple, list)) and any(
        isinstance(item, BaseModel) for item in obj
    ):
        return sum(
            sys.getsizeof(item.__dict__)
            if isinstance(item, BaseModel)
            else sys.getsizeof(item)
            for item in obj
        )

    if isinstance(obj, dict) and any(
        isinstance(value, BaseModel) for value in obj.values()
    ):
        # BUG FIX: a trailing `if isinstance(value, BaseModel)` filter
        # previously excluded non-model values from this sum, undercounting
        # mixed dicts. Every value is now measured, mirroring the
        # list/tuple branch above.
        return sum(
            sys.getsizeof(value.__dict__)
            if isinstance(value, BaseModel)
            else sys.getsizeof(value)
            for value in obj.values()
        )

    return sys.getsizeof(obj)
def time_encode_decode(
    objects: list[Any],
    serialization: SerializationTypesAlias,
    encoding: EncodingTypesAlias,
    pydantic_models: list[type[BaseModel]] | None,
    num_iterations: int,
) -> tuple[float, float, float, float]:
    """Benchmark one (serialization, encoding) combination.

    Encodes and decodes every object ``num_iterations`` times through
    ``MessageEncoding`` (plus a ``pickle`` round-trip, mimicking what a
    multiprocessing queue would do to the message).

    :param objects: Payloads to round-trip.
    :param serialization: Serialization strategy passed to MessageEncoding.
    :param encoding: Wire encoding passed to MessageEncoding.
    :param pydantic_models: Models to register with the encoder, if any.
    :param num_iterations: Number of passes over ``objects``.
    :return: ``(avg_encode_ns, avg_decode_ns, avg_msg_size_bytes,
        accuracy_pct)`` — times are averaged **per operation**.
    """
    message_encoding = MessageEncoding(serialization=serialization, encoding=encoding)
    if pydantic_models:
        for model in pydantic_models:
            message_encoding.register_pydantic(model)

    # Robustness: avoid ZeroDivisionError on an empty benchmark set.
    if not objects or num_iterations <= 0:
        return (0.0, 0.0, 0.0, 0.0)

    msg_sizes = []
    decoded = []
    encode_time = 0.0
    decode_time = 0.0

    for _ in range(num_iterations):
        for obj in objects:
            start = time.perf_counter_ns()
            message = message_encoding.encode(obj)
            pickled_msg = pickle.dumps(message)
            end = time.perf_counter_ns()
            encode_time += end - start

            msg_sizes.append(calculate_size(pickled_msg))

            start = time.perf_counter_ns()
            message = pickle.loads(pickled_msg)
            decoded.append(message_encoding.decode(message=message))
            end = time.perf_counter_ns()
            decode_time += end - start

    # Accuracy check: zip truncates, so only the first iteration's decodes
    # are compared — sufficient because every iteration is identical.
    correct = 0
    for obj, dec in zip(objects, decoded):
        if (
            obj == dec
            or type(obj) is type(dec)
            and (
                (
                    hasattr(obj, "model_dump")
                    and hasattr(dec, "model_dump")
                    and obj.model_dump() == dec.model_dump()
                )
                or str(obj) == str(dec)
            )
        ):
            correct += 1

    percent_differences = 100.0 * correct / len(objects)
    avg_msg_size = np.mean(msg_sizes)

    # BUG FIX: times were accumulated over num_iterations passes but divided
    # only by len(objects), inflating the reported per-operation time by a
    # factor of num_iterations. Normalize by the total operation count.
    total_ops = num_iterations * len(objects)
    return (
        encode_time / total_ops,
        decode_time / total_ops,
        avg_msg_size,
        percent_differences,
    )
def run_benchmarks(objects_size: int, num_objects: int, num_iterations: int):
    """Sweep every (object type, serialization, encoding) combination and
    print the aggregated results to stdout as a CSV table.

    :param objects_size: Size of each generated test object, in bytes.
    :param num_objects: How many objects each test set contains.
    :param num_iterations: Passes per combination in time_encode_decode.
    """
    header = [
        "Object Type",
        "Serialization",
        "Encoding",
        "Encode Time (ns)",
        "Decode Time (ns)",
        "Total Time (ns)",
        "Avg Message Size (bytes)",
        "Accuracy (%)",
        "Error",
    ]
    rows = []

    for obj_type, objects, pydantic_models in create_all_test_objects(
        objects_size=objects_size,
        num_objects=num_objects,
    ):
        for serialization in ("dict", "sequence", None):
            for encoding in ("msgpack", "msgspec", None):
                enc_ns = dec_ns = total_ns = size = accuracy = None
                error = None
                try:
                    enc_ns, dec_ns, size, accuracy = time_encode_decode(
                        objects=objects,
                        serialization=serialization,
                        encoding=encoding,
                        pydantic_models=pydantic_models,
                        num_iterations=num_iterations,
                    )
                    total_ns = enc_ns + dec_ns
                except Exception as err:  # keep sweeping past failing combos
                    print(
                        f"Error occurred while benchmarking {obj_type} for "
                        f"serialization={serialization} and encoding={encoding}: {err}"
                    )
                    error = err
                rows.append(
                    [
                        obj_type,
                        serialization,
                        encoding,
                        enc_ns,
                        dec_ns,
                        total_ns,
                        size,
                        accuracy,
                        error,
                    ]
                )

    # Render all rows at once and print the CSV table.
    buffer = io.StringIO()
    table = csv.writer(buffer)
    table.writerow(header)
    table.writerows(rows)
    print(buffer.getvalue())
@click.command()
@click.option("--size", default=1024, type=int, help="Size of each object in bytes")
@click.option(
    "--objects", default=1000, type=int, help="Number of objects to benchmark"
)
@click.option("--iterations", default=5, type=int, help="Number of iterations to run")
def main(size: int, objects: int, iterations: int):
    """CLI entry point: run the message-encoding benchmarks.

    Parameter names must match the click option names above — do not rename.
    """
    # Fixed seed so generated test objects are reproducible across runs.
    random.seed(42)
    run_benchmarks(objects_size=size, num_objects=objects, num_iterations=iterations)
if __name__ == "__main__":
    # CONSISTENCY FIX: previously this called run_benchmarks() directly with
    # hard-coded debug values (num_objects=10) and skipped the
    # random.seed(42) the CLI path performs. Route through the click command
    # so `python -m ...` honors --size/--objects/--iterations.
    main()

0 commit comments

Comments
 (0)