diff --git a/examples/perf_comparison_example.py b/examples/perf_comparison_example.py new file mode 100644 index 000000000..57411e760 --- /dev/null +++ b/examples/perf_comparison_example.py @@ -0,0 +1,46 @@ +import pandas as pd + +from xorq.common.utils.import_utils import ( + import_from_github, +) +from xorq.common.utils.perf_utils import ( + compare_runs, +) + + +tag = "v0.2.2" +lib = import_from_github( + "xorq-labs", "xorq", "examples/complex_cached_expr.py", tag=tag +) + + +if __name__ == "__main__": + (train_predicted, *_) = lib.make_exprs() + (cleared, uncached_df, cached_df) = compare_runs(train_predicted) + uncached_duration, cached_duration = ( + (df.end_datetime.max() - df.start_datetime.min()).total_seconds() + for df in (uncached_df, cached_df) + ) + delta_series = pd.Series( + { + "uncached_duration": uncached_duration, + "cached_duration": cached_duration, + "delta_duration": cached_duration - uncached_duration, + } + ) + (cache_miss_events, cache_hit_events) = ( + pd.concat( + ( + pd.DataFrame( + dct for dct in trace.cache_event_dcts if dct["name"] == name + ) + for trace in df.trace + ), + ignore_index=True, + ) + for (df, name) in ( + (uncached_df, "cache.miss"), + (cached_df, "cache.hit"), + ) + ) + print(delta_series.round(2)) diff --git a/python/xorq/common/utils/import_utils.py b/python/xorq/common/utils/import_utils.py index 833f4394c..d9ba287a4 100644 --- a/python/xorq/common/utils/import_utils.py +++ b/python/xorq/common/utils/import_utils.py @@ -83,9 +83,8 @@ def import_from_path(path, module_name="__main__"): ) -def import_from_gist(user, gist): - path = f"https://gist.githubusercontent.com/{user}/{gist}/raw/" - req = urllib.request.Request(path, method="GET") +def import_from_url(url): + req = urllib.request.Request(url, method="GET") resp = urllib.request.urlopen(req) if resp.code != 200: raise ValueError @@ -94,3 +93,22 @@ def import_from_gist(user, gist): path.write_text(resp.read().decode("ascii")) module = import_python(path) return module + + +def import_from_gist(user, gist): + url = f"https://gist.githubusercontent.com/{user}/{gist}/raw/" + return import_from_url(url) + + +def import_from_github(user, repo, path, *, tag=None, branch=None, commit=None): + if tag: + infix = f"refs/tags/{tag}" + elif branch: + infix = f"refs/heads/{branch}" + elif commit: + infix = commit + else: + raise ValueError("one of tag, branch, commit must be non None") + + url = f"https://raw.githubusercontent.com/{user}/{repo}/{infix}/{path}" + return import_from_url(url) diff --git a/python/xorq/common/utils/perf_utils.py b/python/xorq/common/utils/perf_utils.py new file mode 100644 index 000000000..b16351d4c --- /dev/null +++ b/python/xorq/common/utils/perf_utils.py @@ -0,0 +1,56 @@ +from datetime import datetime +from time import sleep + +import pandas as pd + +from xorq.common.utils.trace_utils import ( + Trace, +) + + +def clear_caches(expr): + def clear_cache(node): + from xorq.caching import ParquetStorage + from xorq.expr.relations import CachedNode + + assert isinstance(node, CachedNode) + expr = node.to_expr() + if expr.ls.exists(): + storage = expr.ls.storage + key = expr.ls.get_key() + if isinstance(storage, ParquetStorage): + key = node.storage.cache.storage.get_loc(key) + key.unlink() + else: + storage.cache.drop(node.parent) + return key + else: + return None + + return tuple(clear_cache(node) for node in expr.ls.cached_nodes) + + +def compare_runs(expr, sleep_duration=5): + cleared = clear_caches(expr) + first_cutoff = datetime.now() + expr.execute() + second_cutoff = datetime.now() + expr.execute() + sleep(sleep_duration) + (traces, partials) = Trace.process_path() + assert not partials + df = pd.DataFrame( + { + "trace_id": trace.trace_id, + "start_datetime": trace.start_datetime, + "end_datetime": trace.end_datetime, + "duration": trace.duration, + "trace": trace, + } + for trace in traces + ) + (first, second) = ( + df[lambda t: t.start_datetime.between(first_cutoff, second_cutoff)], + df[lambda t: t.start_datetime.ge(second_cutoff)], + ) + return cleared, first, second diff --git a/python/xorq/common/utils/trace_utils.py b/python/xorq/common/utils/trace_utils.py index 5d2faf636..45db3c7b1 100644 --- a/python/xorq/common/utils/trace_utils.py +++ b/python/xorq/common/utils/trace_utils.py @@ -191,6 +191,8 @@ def cache_event_dct(self): "duration": self.duration, "name": event.name, "key": attribute.value, + "start_datetime": self.start_datetime, + "end_datetime": self.end_datetime, } else: return None @@ -384,6 +386,21 @@ def get_lineage(self, span_id): lineage += (span,) return lineage + @property + def attribute_df(self): + import pandas as pd + + return pd.DataFrame( + { + f"attribute.{getattr(attribute, 'name')}": getattr(attribute, "value") + for attribute in event.attributes + } + | {f"event.{k}": getattr(event, k) for k in ("time", "name")} + | {k: getattr(span, k) for k in ("trace_id", "span_id", "name")} + for span in self.spans + for event in span.events + ) + def get_depth(self, depth): return self.get_depths().get(depth, ()) @@ -439,6 +456,10 @@ def duration(self): def start_datetime(self): return self.parent_span.start_datetime + @property + def end_datetime(self): + return self.parent_span.end_datetime + def get_spans_named(self, name): return tuple(span for span in self.spans if span.name == name) diff --git a/python/xorq/tests/test_examples.py b/python/xorq/tests/test_examples.py index cf3e81a7e..22d3f8372 100644 --- a/python/xorq/tests/test_examples.py +++ b/python/xorq/tests/test_examples.py @@ -14,6 +14,7 @@ "mcp_flight_server.py", "duckdb_flight_example.py", "complex_cached_expr.py", + "perf_comparison_example.py", ) file_path = pathlib.Path(__file__).absolute()