Commit d698d66

【Hackathon 9th Sprint No.9】feat: implement ES(t) macro/micro cross-validation and refactor analysis utilities (#363)
* feat: implement ES(t) macro/micro cross-validation and refactor analysis utilities

  This commit implements the Error-aware Speedup Score (ES_t) metric from Section 3.2.2 of the technical report (arXiv:2510.24035), along with the mathematical proofs from Appendix B and C that establish the sample-level validity of both S_t and ES_t metrics.

  Key Features:
  =============
  1. Appendix B Implementation - Sample-level proof for S_t:
     - Micro-level calculation: geometric mean of rectified speedups for all samples
     - Macro-level calculation: S_t = α^λ · β^(ληp) · b^(1-λ)
     - Cross-validation: both methods produce identical results, proving S_t is equivalent to the geometric mean of sample-level rectified speedups
  2. Appendix C Implementation - Sample-level proof for ES_t:
     - Micro-level calculation: geometric mean of error-aware rectified speedups
     - Macro-level calculation: ES_t = α^λ · β^(ληp) · γ_t^(1-λ)
     - Dynamic penalty factor: γ_t = b^(Σ_c π_c · 1[t < c])
     - Cross-validation: validates that ES_t is the geometric mean of error-aware rectified speedups, where failure samples use type-specific dynamic penalties instead of the fixed penalty b
  3. Error-aware design (Section 3.2.2):
     - Error type classification: c=1 (accuracy), c=2 (runtime crash), c=3 (compile failure)
     - Tiered tolerance rules: t≥1 tolerates accuracy errors, t≥2 tolerates runtime crashes, t≥3 tolerates all errors
     - Dynamic penalty γ_t adapts to the error type distribution and tolerance level
  4. Independent verification script:
     - verify_macro_params.py: calculates and prints all macro parameters (alpha, beta, gamma, lambda, eta, pi) independently
     - Enables validation of plot_ESt results by computing each parameter separately
  5. Mandatory validation mechanism:
     - plot_ESt.py: enforces macro/micro result matching before adoption
     - Rejects results if validation fails, ensuring calculation correctness
  6. Code refactoring for maintainability:
     - macro_statistics.py: dedicated module for macro parameter calculations
     - Each parameter has an independent function (alpha, beta, gamma, lambda, eta, pi)
     - Reduced nesting levels in analysis_util.py by extracting helper functions
     - Simplified scan_all_folders and added .txt file support
     - Improved code organization following software engineering best practices

  Technical Details:
  ==================
  - Micro calculation: processes each sample individually, applies rectified speedup rules, then computes the geometric mean
  - Macro calculation: uses aggregated statistics (correct count, speedup distributions, error type proportions) to compute expected values
  - Validation: compares micro and macro results with a tolerance threshold (1e-6)
  - All calculations verified against real benchmark data (118 samples)

  Files Changed:
  ==============
  - graph_net/analysis_util.py: refactored with helper functions, integrated the macro_statistics module, reduced nesting, simplified scan_all_folders
  - graph_net/macro_statistics.py: new module for macro parameter calculations
  - graph_net/plot_ESt.py: added mandatory macro/micro validation
  - graph_net/verify_macro_params.py: new independent verification script

  All code passes pre-commit checks, compiles successfully, and has been validated with real benchmark data.

* refactor: rename macro to aggregated and improve code quality

  This commit refactors the evaluation metrics calculation code with the following improvements:

  1. Terminology refactoring: macro -> aggregated
     - Rename macro_statistics.py to samples_statistics.py
     - Rename verify_macro_params.py to verify_aggregated_params.py
     - Update all variable and function names accordingly
  2. Code structure improvements
     - Extract verification logic in plot_ESt.py into separate functions:
       * compare_single_tolerance_level (12 lines)
       * print_verification_result (1 line)
       * verify_aggregated_micro_consistency (28 lines, meets the ≤30-line requirement)
     - Refactor verify_aggregated_params.py to use a functional programming style:
       * Replace structured loops with list comprehensions
       * Use Counter for error type counting
       * Reduce multiple traversals to a single pass where possible
  3. Reduce function parameter coupling
     - calculate_beta: derive slowdown_speedups internally from correct_speedups
     - calculate_lambda: derive correct_count internally from correct_speedups
     - calculate_eta: derive statistics internally from correct_speedups
  4. Decouple error type handling
     - calculate_pi: accept error_type_counts (dict) instead of hardcoded types
     - calculate_gamma: accept generic parameters (tolerance, get_pi, errno_tolerances)
     - Support user-defined error codes instead of hardcoded error types
  5. Code quality improvements
     - Use explicit len() checks instead of implicit boolean conversion
     - Use modern Python type hints (list/tuple instead of typing.List/Tuple)
     - Improve code readability and maintainability

  All changes have been verified and pass pre-commit checks.
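The macro/micro identity described above can be sketched in a few lines. This is a simplified, hypothetical reconstruction: the sample layout and field names are assumptions, and the decomposition is reduced to λ and γ_t (the full parameterization with α, β, η, and p follows the report). It shows why grouping the geometric mean by correct vs. failed samples reproduces the dynamic penalty γ_t = b^(Σ π_c · 1[t < c]) exactly.

```python
import math
from collections import Counter

# Hypothetical sample layout: correct samples carry a measured speedup, failed
# samples carry an integer error type (1=accuracy, 2=runtime crash, 3=compile failure).
samples = [
    {"ok": True, "speedup": 1.8},
    {"ok": True, "speedup": 0.9},
    {"ok": False, "errno": 1},
    {"ok": False, "errno": 3},
]
b = 0.1  # base penalty for a non-tolerated failure


def micro_es(samples, t, b):
    """Micro level: geometric mean of error-aware rectified speedups, one per sample."""
    factors = [
        s["speedup"] if s["ok"] else (1.0 if t >= s["errno"] else b)
        for s in samples
    ]
    return math.prod(factors) ** (1.0 / len(factors))


def aggregated_es(samples, t, b):
    """Aggregated level: G_correct^lambda * gamma_t^(1-lambda) from summary statistics."""
    correct = [s["speedup"] for s in samples if s["ok"]]
    failed = [s["errno"] for s in samples if not s["ok"]]
    lam = len(correct) / len(samples)
    g_correct = math.prod(correct) ** (1.0 / len(correct))
    pi = {c: n / len(failed) for c, n in Counter(failed).items()}
    # Dynamic penalty: gamma_t = b ** sum(pi_c for error types c not tolerated at t)
    gamma_t = b ** sum(p for c, p in pi.items() if t < c)
    return g_correct**lam * gamma_t ** (1 - lam)


# Cross-validation: both calculations agree at every tolerance level.
for t in range(4):
    assert math.isclose(micro_es(samples, t, b), aggregated_es(samples, t, b))
```

Because the geometric mean factorizes over any partition of the samples, the aggregated form is an exact regrouping of the micro form, not an approximation — which is what makes the mandatory cross-validation a meaningful correctness check.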
* style: apply black formatting to samples_statistics.py and verify_aggregated_params.py

* refactor: unify error type to errno mapping for better sorting

  - Replace error_type_counts (dict[str, int]) with errno2count (dict[int, int])
  - Add get_errno_from_error_type() to map error type strings to errno (1, 2, 3)
  - Add get_error_type_from_errno() for reverse mapping when error type strings are needed
  - Update calculate_pi() to use errno2count and return dict[int, float]
  - Update calculate_all_aggregated_parameters() to use errno2count and errno_tolerance_thresholds
  - Update analysis_util.py and verify_aggregated_params.py to use errno2count
  - Improve code maintainability by using integer errno for sorting and comparison

* refactor: split tolerance report generation

* refactor: improve naming and semantics for ES calculation

  - Rename verify_es_match_at_tolerance to compare_aggregated_es_and_microscopic_es
  - Replace tolerance_level with the tolerance parameter
  - Replace tolerance_threshold with atol/rtol to avoid confusion
  - Rename verify_aggregated_microscopic_consistency to get_verified_aggregated_es_values
  - Change the return type to dict only (remove all_matched)
  - Rename verified_scores to verified_es_values
  - Replace micro with microscopic throughout
  - Rename check_sample_correctness to get_sample_correctness
  - Rename t1 variables to first_errno_tolerance
  - Rename es_components to es_constructor_params
  - Rename calculate_parameters_for_tolerance to calculate_es_constructor_params_for_tolerance
  - Rename custom_map to errno_tolerance_overrides
  - Rename errno_as_tolerances to errno2tolerance
  - Add the enable_aggregation_mode command-line option

* feat: add aggregated ES(t) plotting and verification

  - Modified plot_ES_results to return fig, ax, all_x_coords for external plotting
  - Added manual plotting of aggregated ES(t) curves in the main function
  - Both microscopic and aggregated curves are plotted on the same graph
  - Aggregated curves use dashed lines with square markers for distinction
  - All verification checks pass with floating-point precision differences (1.39e-17)

* fix: move ax.legend outside the aggregation condition block

  - Move ax.legend() outside the aggregation mode condition block
  - Ensure the legend is always displayed regardless of aggregation mode
  - Fix the issue where the legend was missing when aggregation mode is disabled
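The errno2count refactor described above can be illustrated with a small sketch. The string names and the calculate_pi body here are hypothetical, modeled on the commit description rather than copied from samples_statistics.py:

```python
from collections import Counter

# Hypothetical mapping between error-type strings and integer errnos (1, 2, 3),
# mirroring the described get_errno_from_error_type / get_error_type_from_errno pair.
ERROR_TYPE_TO_ERRNO = {"accuracy": 1, "runtime_crash": 2, "compile_failure": 3}
ERRNO_TO_ERROR_TYPE = {v: k for k, v in ERROR_TYPE_TO_ERRNO.items()}


def calculate_pi(errno2count: dict[int, int]) -> dict[int, float]:
    """Proportion of each error type among all failures (illustrative sketch)."""
    total = sum(errno2count.values())
    # Integer errnos sort naturally, which was the point of the refactor.
    return {errno: n / total for errno, n in sorted(errno2count.items())}


fail_types = ["accuracy", "accuracy", "compile_failure", "runtime_crash"]
errno2count = Counter(ERROR_TYPE_TO_ERRNO[t] for t in fail_types)
print(calculate_pi(dict(errno2count)))  # {1: 0.5, 2: 0.25, 3: 0.25}
```

Keying on integers rather than strings makes sorting, comparison against the tolerance level t, and the γ_t exponent sum straightforward.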
1 parent 5133cb7 commit d698d66

File tree

4 files changed: +1047 / -6 lines changed

graph_net/analysis_util.py

Lines changed: 48 additions & 0 deletions
@@ -164,6 +164,7 @@ def parse_logs_to_data(log_file: str) -> list:
         List of data dictionaries, each containing configuration, correctness,
         performance, and result information for a single model-compiler run.
     """
+
     try:
         with open(log_file, "r", encoding="utf-8") as f:
             lines = f.readlines()
@@ -330,10 +331,13 @@ def scan_all_folders(benchmark_path: str) -> dict:
     """
     Unified entry point that supports log files and directories:
     - If benchmark_path is a log file (.log or .txt) → parse it directly and return data as a single curve.
+
     - If benchmark_path is a directory → scan for .log and .txt files in the directory,
       each log file becomes a curve.
+
     Returns dict[curve_name] -> list_of_samples
     """
+
     # Handle single log file
     if os.path.isfile(benchmark_path):
         print(f"Detected log file: '{benchmark_path}'")
@@ -648,3 +652,47 @@ def print_stat_info(
     print(f" - pi: {pi}")

     return s_scores, s_scores_fake_degrad
+
+
+def check_sample_correctness(sample: dict, t_key: int) -> tuple[bool, str]:
+    """
+    Check if a sample is correct at the given tolerance level.
+
+    Args:
+        sample: Sample data dictionary
+        t_key: Tolerance level
+
+    Returns:
+        Tuple of (is_correct, fail_type)
+        - is_correct: True if sample is correct at this tolerance
+        - fail_type: Error type if not correct, None if correct
+    """
+    performance_data = sample.get("performance", {})
+    fail_type = performance_data.get("failure")
+
+    # If there's already a failure type, return it
+    if fail_type is not None:
+        return False, fail_type
+
+    # Check correctness based on datatype and tolerance
+    datatype_data = performance_data.get("datatype", {})
+    eager_dtypes = datatype_data.get("eager", [])
+    compiled_dtypes = datatype_data.get("compiled", [])
+
+    # Check if datatypes match and are valid
+    if not (eager_dtypes and eager_dtypes == compiled_dtypes and len(eager_dtypes) > 0):
+        return False, "accuracy"
+
+    correctness_data = sample.get("correctness", {})
+    output_count = len(correctness_data.get("[equal]", []))
+
+    if len(eager_dtypes) != output_count:
+        return False, "accuracy"
+
+    # Check all outputs for correctness
+    is_correct = all(
+        get_correctness(eager_dtypes[i], t_key, correctness_data, i)
+        for i in range(output_count)
+    )
+
+    return is_correct, None if is_correct else "accuracy"
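To illustrate the data shape this helper consumes, here is a self-contained sketch. The get_correctness stub and the field values below are hypothetical (the real helper lives in analysis_util); only the branching structure mirrors the function added above.

```python
# Stub for analysis_util.get_correctness (hypothetical): output i counts as
# correct when its "[equal]" flag is True, regardless of dtype or tolerance.
def get_correctness(dtype, t_key, correctness_data, i):
    return bool(correctness_data["[equal]"][i])


def check_sample_correctness(sample, t_key):
    """Condensed version of the branching logic shown in the diff above."""
    perf = sample.get("performance", {})
    if perf.get("failure") is not None:          # pre-recorded failure type wins
        return False, perf["failure"]
    dtypes = perf.get("datatype", {})
    eager = dtypes.get("eager", [])
    if not eager or eager != dtypes.get("compiled", []):
        return False, "accuracy"                  # dtype mismatch counts as accuracy error
    corr = sample.get("correctness", {})
    if len(eager) != len(corr.get("[equal]", [])):
        return False, "accuracy"
    ok = all(get_correctness(eager[i], t_key, corr, i) for i in range(len(eager)))
    return ok, None if ok else "accuracy"


crashed = {"performance": {"failure": "runtime_crash"}}
passing = {
    "performance": {"datatype": {"eager": ["float32"], "compiled": ["float32"]}},
    "correctness": {"[equal]": [True]},
}
print(check_sample_correctness(crashed, 1))  # (False, 'runtime_crash')
print(check_sample_correctness(passing, 1))  # (True, None)
```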

graph_net/plot_ESt.py

Lines changed: 281 additions & 6 deletions
@@ -3,6 +3,163 @@
 import numpy as np
 import matplotlib.pyplot as plt
 from graph_net import analysis_util
+from graph_net import verify_aggregated_params
+
+
+class ESScoresWrapper:
+    """Wrapper for es_scores dict to allow attribute assignment."""
+
+    def __init__(self, es_scores_dict):
+        self._dict = es_scores_dict
+        self._aggregated_results = {}
+
+    def items(self):
+        return self._dict.items()
+
+    def __getitem__(self, key):
+        return self._dict[key]
+
+    def __setitem__(self, key, value):
+        self._dict[key] = value
+
+
+def es_result_checker(
+    es_from_microscopic: float, es_from_macro: float, atol: float, rtol: float
+) -> bool:
+    """
+    Check if ES(t) values from microscopic and macro calculations match.
+
+    Args:
+        es_from_microscopic: ES(t) value from microscopic-level calculation
+        es_from_macro: ES(t) value from aggregated-level calculation
+        atol: Absolute tolerance for comparison
+        rtol: Relative tolerance for comparison
+
+    Returns:
+        True if values match within tolerance, False otherwise
+    """
+    return np.allclose(es_from_microscopic, es_from_macro, rtol=rtol, atol=atol)
+
+
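es_result_checker leans on np.allclose, which accepts plain scalars as well as arrays; the match criterion is |a − b| ≤ atol + rtol·|b|. A quick sanity check with the atol/rtol values used in this file:

```python
import numpy as np

# np.allclose accepts scalars; the criterion is |a - b| <= atol + rtol * |b|.
assert np.allclose(1.000001, 1.0, rtol=1e-3, atol=1e-3)   # 1e-6 <= 2e-3: match
assert not np.allclose(1.01, 1.0, rtol=1e-3, atol=1e-3)   # 1e-2 >  2e-3: mismatch
```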
+def compare_aggregated_es_and_microscopic_es(
+    tolerance: int,
+    microscopic_es: float,
+    aggregated_es: float | None,
+    atol: float = 1e-3,
+    rtol: float = 1e-3,
+) -> tuple[bool, float, float]:
+    """
+    Compare ES(t) values from aggregated and microscopic calculations at a tolerance level.
+
+    Args:
+        tolerance: Tolerance level t
+        microscopic_es: ES(t) value from microscopic-level calculation
+        aggregated_es: ES(t) value from aggregated-level calculation, or None if missing
+        atol: Absolute tolerance for comparison
+        rtol: Relative tolerance for comparison
+
+    Returns:
+        Tuple of (is_matched, diff, relative_diff)
+    """
+    if aggregated_es is None:
+        return False, 0.0, 0.0
+
+    diff = abs(microscopic_es - aggregated_es)
+    relative_diff = diff / max(abs(microscopic_es), abs(aggregated_es), 1e-10)
+    is_matched = es_result_checker(microscopic_es, aggregated_es, atol, rtol)
+
+    return is_matched, diff, relative_diff
+
+
+def print_verification_result(
+    tolerance: int,
+    microscopic_es: float,
+    aggregated_es: float | None,
+    diff: float,
+    relative_diff: float,
+    is_matched: bool,
+) -> None:
+    """Print verification result for a single tolerance level."""
+    if aggregated_es is None:
+        print(f"ERROR: No aggregated result for t={tolerance}, cannot verify")
+    elif is_matched:
+        print(
+            f"t={tolerance:3d}: MATCHED - Microscopic: {microscopic_es:.6f}, Aggregated: {aggregated_es:.6f}, Diff: {diff:.2e}"
+        )
+    else:
+        print(
+            f"t={tolerance:3d}: MISMATCH - Microscopic: {microscopic_es:.6f}, Aggregated: {aggregated_es:.6f}, Diff: {diff:.2e} ({relative_diff*100:.4f}%)"
+        )
+
+
+def get_verified_aggregated_es_values(es_scores: dict, folder_name: str) -> dict:
+    """
+    Get verified ES(t) values by checking consistency between aggregated and microscopic-level calculations.
+
+    Args:
+        es_scores: Dictionary of ES(t) scores from microscopic-level calculation
+        folder_name: Name of the folder being verified
+
+    Returns:
+        Dictionary of verified ES(t) values (only matched tolerance levels).
+
+    Raises:
+        AssertionError: If aggregated and microscopic results do not match (fail-fast).
+    """
+    aggregated_results = getattr(es_scores, "_aggregated_results", {})
+    verified_es_values = {}
+    mismatches = []
+
+    print(f"\n{'='*80}")
+    print(f"Verifying Aggregated/Microscopic Consistency for '{folder_name}'")
+    print(f"{'='*80}")
+
+    for tolerance, microscopic_es in es_scores.items():
+        aggregated_es = aggregated_results.get(tolerance)
+        is_matched, diff, relative_diff = compare_aggregated_es_and_microscopic_es(
+            tolerance, microscopic_es, aggregated_es
+        )
+
+        print_verification_result(
+            tolerance,
+            microscopic_es,
+            aggregated_es,
+            diff,
+            relative_diff,
+            is_matched,
+        )
+
+        if aggregated_es is None:
+            mismatches.append(
+                f"t={tolerance}: Missing aggregated result (microscopic={microscopic_es:.6f})"
+            )
+        elif not is_matched:
+            mismatches.append(
+                f"t={tolerance}: Mismatch - Microscopic={microscopic_es:.6f}, "
+                f"Aggregated={aggregated_es:.6f}, Diff={diff:.2e} ({relative_diff*100:.4f}%)"
+            )
+        else:
+            verified_es_values[tolerance] = microscopic_es
+
+    if mismatches:
+        error_msg = (
+            f"\n{'='*80}\n"
+            f"ERROR: Aggregated and microscopic results do not match for '{folder_name}'!\n"
+            f"{'='*80}\n"
+            f"Mismatches:\n"
+            + "\n".join(f"  - {mismatch}" for mismatch in mismatches)
+            + f"\n\nCalculation validation failed. Please verify the calculation logic "
+            f"using verify_aggregated_params.py\n"
+            f"{'='*80}\n"
+        )
+        print(error_msg)
+        raise AssertionError(error_msg)
+
+    print(
+        f"\nSUCCESS: All aggregated and microscopic results match for '{folder_name}'."
+    )
+    print(f"{'='*80}\n")
+    return verified_es_values


 def plot_ES_results(s_scores: dict, cli_args: argparse.Namespace):
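The fail-fast contract of get_verified_aggregated_es_values — collect every mismatch first, then raise a single AssertionError — can be reduced to a minimal pattern. All names and values here are illustrative:

```python
def verify_all(pairs: dict[int, tuple[float, float]], atol: float = 1e-3) -> dict[int, float]:
    """Return {tolerance: micro_es} if every (micro, aggregated) pair matches; raise once otherwise."""
    # Gather all offending tolerance levels before failing, so one run reports everything.
    mismatches = [t for t, (micro, agg) in pairs.items() if abs(micro - agg) > atol]
    if mismatches:
        raise AssertionError(f"mismatched tolerance levels: {mismatches}")
    return {t: micro for t, (micro, _) in pairs.items()}


print(verify_all({0: (0.35, 0.35), 1: (0.63, 0.63)}))  # {0: 0.35, 1: 0.63}
```

Reporting all mismatches at once, rather than stopping at the first, is what makes the verification output in plot_ESt.py useful for debugging a broken aggregated parameter.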
@@ -93,10 +250,7 @@ def plot_ES_results(s_scores: dict, cli_args: argparse.Namespace):
     ax.xaxis.grid(True, which="major", lw=0.7, ls=":", color="grey", alpha=0.5)
     ax.yaxis.grid(True, which="major", lw=0.7, ls=":", color="grey", alpha=0.5)

-    ax.legend(fontsize=16, loc="best")
-    output_file = os.path.join(cli_args.output_dir, "ES_result.png")
-    plt.savefig(output_file, dpi=300, bbox_inches="tight")
-    print(f"\nComparison plot saved to {output_file}")
+    return fig, ax, all_x_coords


 def main():
@@ -130,6 +284,18 @@ def main():
         default=0.1,
         help="Base penalty for severe errors (e.g., crashes, correctness failures).",
     )
+    parser.add_argument(
+        "--enable-aggregation-mode",
+        action="store_true",
+        help="Enable aggregation mode to verify aggregated/microscopic consistency. Default: enabled.",
+    )
+    parser.add_argument(
+        "--disable-aggregation-mode",
+        dest="enable_aggregation_mode",
+        action="store_false",
+        help="Disable aggregation mode verification.",
+    )
+    parser.set_defaults(enable_aggregation_mode=True)
     args = parser.parse_args()

     # 1. Scan folders to get data
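The paired flags above are a standard argparse enable/disable toggle: both options write to the same dest, and set_defaults makes the feature opt-out. A minimal standalone reproduction:

```python
import argparse

# Minimal reproduction of the paired enable/disable flag pattern used above.
parser = argparse.ArgumentParser()
parser.add_argument("--enable-aggregation-mode", action="store_true")
parser.add_argument(
    "--disable-aggregation-mode",
    dest="enable_aggregation_mode",  # same dest as the enable flag
    action="store_false",
)
parser.set_defaults(enable_aggregation_mode=True)  # on unless explicitly disabled

assert parser.parse_args([]).enable_aggregation_mode is True
assert parser.parse_args(["--disable-aggregation-mode"]).enable_aggregation_mode is False
```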
@@ -138,21 +304,130 @@ def main():
         print("No valid data found. Exiting.")
         return

-    # 2. Calculate scores for each curve
+    # 2. Calculate scores for each curve and verify aggregated/microscopic consistency
     all_es_scores = {}
+    all_aggregated_results = {}
+
     for folder_name, samples in all_results.items():
         _, es_scores = analysis_util.calculate_s_scores(
             samples,
             folder_name,
             negative_speedup_penalty=args.negative_speedup_penalty,
             fpdb=args.fpdb,
         )
+
+        # Keep original behavior: assign es_scores directly
         all_es_scores[folder_name] = es_scores

+        # Verify aggregated/microscopic consistency if aggregation mode is enabled
+        if args.enable_aggregation_mode:
+            # Calculate aggregated results and attach to es_scores
+            aggregated_results = (
+                verify_aggregated_params.verify_es_constructor_params_across_tolerances(
+                    samples,
+                    folder_name,
+                    negative_speedup_penalty=args.negative_speedup_penalty,
+                    fpdb=args.fpdb,
+                )
+            )
+            # Store aggregated results for plotting
+            all_aggregated_results[folder_name] = aggregated_results
+
+            # Extract expected_es values and attach as _aggregated_results
+            # Wrap es_scores to allow attribute assignment
+            es_scores_wrapper = ESScoresWrapper(es_scores)
+            es_scores_wrapper._aggregated_results = {
+                tolerance: result["expected_es"]
+                for tolerance, result in aggregated_results.items()
+            }
+
+            # Fail-fast: raise AssertionError if validation fails
+            verified_es_values = get_verified_aggregated_es_values(
+                es_scores_wrapper, folder_name
+            )
+            all_es_scores[folder_name] = verified_es_values
+
     # 3. Plot the results
     if any(all_es_scores.values()):
         os.makedirs(args.output_dir, exist_ok=True)
-        plot_ES_results(all_es_scores, args)
+        fig, ax, all_x_coords = plot_ES_results(all_es_scores, args)
+
+        # Manually add aggregated curves if available
+        if args.enable_aggregation_mode and all_aggregated_results:
+            prop_cycle = plt.rcParams["axes.prop_cycle"]
+            colors = prop_cycle.by_key()["color"]
+
+            for idx, (folder_name, aggregated_results) in enumerate(
+                all_aggregated_results.items()
+            ):
+                if folder_name not in all_es_scores:
+                    continue
+
+                color = colors[idx % len(colors)]
+                agg_plot_points = []
+                for tolerance, result in aggregated_results.items():
+                    if isinstance(result, dict) and "expected_es" in result:
+                        agg_plot_points.append(
+                            {"x": tolerance, "y": result["expected_es"]}
+                        )
+
+                if agg_plot_points:
+                    agg_plot_points.sort(key=lambda p: p["x"])
+                    agg_x_vals = np.array([p["x"] for p in agg_plot_points])
+                    agg_y_vals = np.array([p["y"] for p in agg_plot_points])
+
+                    agg_zero_index = (
+                        np.where(agg_x_vals == 0)[0][0] if 0 in agg_x_vals else None
+                    )
+
+                    if agg_zero_index is not None:
+                        ax.plot(
+                            agg_x_vals[: agg_zero_index + 1],
+                            agg_y_vals[: agg_zero_index + 1],
+                            "s--",
+                            color=color,
+                            label=f"{folder_name} (aggregated)",
+                            linewidth=2,
+                            markersize=6,
+                            alpha=0.7,
+                        )
+                        ax.plot(
+                            agg_x_vals[agg_zero_index:],
+                            agg_y_vals[agg_zero_index:],
+                            "s--",
+                            color=color,
+                            linewidth=2,
+                            markersize=6,
+                            drawstyle="steps-post",
+                            alpha=0.7,
+                        )
+                    else:
+                        ax.plot(
+                            agg_x_vals,
+                            agg_y_vals,
+                            "s--",
+                            color=color,
+                            label=f"{folder_name} (aggregated)",
+                            linewidth=2,
+                            markersize=6,
+                            alpha=0.7,
+                        )
+
+        # Update x-axis range if needed
+        if all_x_coords:
+            for folder_name, aggregated_results in all_aggregated_results.items():
+                for tolerance in aggregated_results.keys():
+                    all_x_coords.append(tolerance)
+            x_min = int(np.floor(min(all_x_coords)))
+            x_max = int(np.ceil(max(all_x_coords)))
+            ax.set_xticks(np.arange(x_min, x_max + 1))
+
+        # Always show legend (whether aggregated curves are added or not)
+        ax.legend(fontsize=16, loc="best")
+
+        output_file = os.path.join(args.output_dir, "ES_result.png")
+        plt.savefig(output_file, dpi=300, bbox_inches="tight")
+        print(f"\nComparison plot saved to {output_file}")
     else:
         print("No ES(t) scores were calculated. Skipping plot generation.")
