
Commit 627a038

required changes

1 parent e65b2be

3 files changed: +32, -35 lines

method_comparison/peft_bench/data.py

Lines changed: 6 additions & 9 deletions
@@ -42,17 +42,14 @@ def load_test_prompts(config: dict) -> dict[str, list[str]]:
     # Use the specified prompts file or fall back to default
     prompts_file = config.get("prompts_file", DEFAULT_PROMPTS_PATH)
 
-    try:
-        with open(prompts_file) as f:
-            prompts = json.load(f)
+    with open(prompts_file) as f:
+        prompts = json.load(f)
 
-        # Apply textwrap.dedent to remove leading spaces from multiline prompts
-        for category, prompt_list in prompts.items():
-            prompts[category] = [textwrap.dedent(prompt) for prompt in prompt_list]
+    # Apply textwrap.dedent to remove leading spaces from multiline prompts
+    for category, prompt_list in prompts.items():
+        prompts[category] = [textwrap.dedent(prompt) for prompt in prompt_list]
 
-        return prompts
-    except Exception as e:
-        raise ValueError(f"Error loading prompts from {prompts_file}: {e}")
+    return prompts
 
 
 def truncate_prompt_for_model(
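
With the blanket try/except removed, load_test_prompts lets the underlying errors propagate instead of re-raising everything as a ValueError. A rough usage sketch of what callers now see (the config key comes from the snippet above; the file path is hypothetical):

import json

# Hypothetical caller; "prompts_file" is the key read by load_test_prompts above.
config = {"prompts_file": "prompts/custom_prompts.json"}

try:
    prompts = load_test_prompts(config)
except FileNotFoundError as e:
    # A missing file now surfaces directly rather than as a wrapped ValueError.
    print(f"Prompts file not found: {e}")
except json.JSONDecodeError as e:
    # Likewise for malformed JSON content.
    print(f"Prompts file is not valid JSON: {e}")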

method_comparison/peft_bench/run.py

Lines changed: 12 additions & 17 deletions
@@ -315,26 +315,21 @@ def main():
     # Configure print function based on verbosity
     print_fn = print if args.verbose else lambda *args, **kwargs: None
 
-    try:
-        # Validate experiment path and load configs
-        experiment_name, benchmark_config, peft_config = validate_experiment_path(args.experiment_path)
-
-        print_fn(f"Running benchmark for experiment: {experiment_name}")
+    # Validate experiment path and load configs
+    experiment_name, benchmark_config, peft_config = validate_experiment_path(args.experiment_path)
 
-        # Run the benchmark
-        result = run_benchmark(
-            benchmark_config=benchmark_config,
-            experiment_name=experiment_name,
-            experiment_path=args.experiment_path,
-            print_fn=print_fn,
-        )
+    print_fn(f"Running benchmark for experiment: {experiment_name}")
 
-        # Log and save results
-        log_results(experiment_name, result, print_fn=print)
+    # Run the benchmark
+    result = run_benchmark(
+        benchmark_config=benchmark_config,
+        experiment_name=experiment_name,
+        experiment_path=args.experiment_path,
+        print_fn=print_fn,
+    )
 
-    except Exception as e:
-        print(f"Error: {e}", file=sys.stderr)
-        return 1
+    # Log and save results
+    log_results(experiment_name, result, print_fn=print)
 
     return 0
 
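
The same pattern is applied in main(): the catch-all except block that printed the error and returned 1 is gone, so failures now surface as full tracebacks. A minimal sketch of the entry point under that assumption (the __main__ guard itself is not part of this diff):

import sys

if __name__ == "__main__":
    # main() returns 0 on success; any failure now propagates as an exception
    # with a full traceback instead of a printed "Error: ..." and exit code 1.
    sys.exit(main())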

method_comparison/peft_bench/utils.py

Lines changed: 14 additions & 9 deletions
@@ -19,19 +19,19 @@
 import datetime
 import json
 import os
-import time
 import uuid
 from dataclasses import asdict, dataclass, field
 from enum import Enum
 from typing import Any, Callable, Optional
 
-import psutil  # You might need to install this: pip install psutil
+import psutil
 import torch
 from peft import PeftConfig
 
 
 # Constants
 FILE_NAME_BENCHMARK_PARAMS = "benchmark_params.json"
+FILE_NAME_DEFAULT_CONFIG = "default_config.json"
 
 # Main paths for storing results
 RESULT_PATH = os.path.join(os.path.dirname(__file__), "results")

@@ -73,7 +73,7 @@ def __post_init__(self):
         """Initialize structured data format."""
         # Default run_info
         self.run_info = {
-            "timestamp": datetime.datetime.now().isoformat(),
+            "timestamp": datetime.datetime.now(tz=datetime.timezone.utc).isoformat(),
             "duration": 0.0,
             "status": self.status.value,
             "hardware": {

@@ -347,7 +347,7 @@ def get_variant_configs(self) -> list["BenchmarkConfig"]:
 
 def generate_experiment_id() -> str:
     """Generate a unique experiment ID."""
-    return datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + "_" + uuid.uuid4().hex[:8]
+    return datetime.datetime.now(tz=datetime.timezone.utc).strftime("%Y%m%d_%H%M%S")
 
 
 def validate_experiment_path(path: str) -> tuple[str, "BenchmarkConfig", Any]:
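
Both datetime calls are now timezone-aware UTC, and the experiment ID is reduced to a plain timestamp (the former uuid suffix is no longer appended). A small illustration of the resulting values, with example output in the comments:

import datetime

# Stored in run_info["timestamp"], e.g. "2024-05-01T12:34:56.789012+00:00"
timestamp = datetime.datetime.now(tz=datetime.timezone.utc).isoformat()

# Returned by generate_experiment_id(), e.g. "20240501_123456".
# Note: without the uuid suffix, runs started within the same second share an ID.
experiment_id = datetime.datetime.now(tz=datetime.timezone.utc).strftime("%Y%m%d_%H%M%S")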
@@ -360,11 +360,16 @@ def validate_experiment_path(path: str) -> tuple[str, "BenchmarkConfig", Any]:
 
     # Check for benchmark params file
     benchmark_params_path = os.path.join(path, FILE_NAME_BENCHMARK_PARAMS)
-    if not os.path.exists(benchmark_params_path):
-        raise FileNotFoundError(f"Benchmark params not found: {benchmark_params_path}")
-
-    # Load benchmark config
-    benchmark_config = BenchmarkConfig.from_json(benchmark_params_path)
+    default_config_path = os.path.join(os.path.dirname(__file__), FILE_NAME_DEFAULT_CONFIG)
+
+    # Use benchmark_params.json if exists, otherwise use default config
+    if os.path.exists(benchmark_params_path):
+        benchmark_config = BenchmarkConfig.from_json(benchmark_params_path)
+    elif os.path.exists(default_config_path):
+        print(f"No benchmark_params.json found in {path}, using default configuration")
+        benchmark_config = BenchmarkConfig.from_json(default_config_path)
+    else:
+        raise FileNotFoundError(f"Neither benchmark_params.json nor default_config.json found")
 
     # Try to load PEFT config
     try:
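
With this fallback in place, an experiment directory no longer has to ship its own benchmark_params.json. A sketch of the resulting lookup order as seen from run.py (the experiment directory name is hypothetical):

# Resolution order inside validate_experiment_path:
#   1. <experiment_path>/benchmark_params.json        (per-experiment config)
#   2. <module directory>/default_config.json         (shared default, with a notice printed)
#   3. neither present -> FileNotFoundError
experiment_name, benchmark_config, peft_config = validate_experiment_path(
    "experiments/my_experiment"  # hypothetical experiment directory
)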
