Skip to content

Commit 732c7a5

Browse files
committed
Improve the script
1 parent 736ab8b commit 732c7a5

File tree

3 files changed

+459
-74
lines changed

3 files changed

+459
-74
lines changed

pai-audio-evals/main.py

Lines changed: 36 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,44 @@
11
import random
2+
import re
23
from dataclasses import dataclass
4+
from functools import partial
35
from pathlib import Path
46

57
import logfire
8+
from nltk import edit_distance
69
from pydantic import TypeAdapter
7-
from pydantic_ai import Agent, BinaryContent
10+
from pydantic_ai import Agent, BinaryContent, AudioUrl
11+
from pydantic_evals import Dataset, Case
12+
from pydantic_evals.evaluators import Evaluator, EvaluatorContext, EvaluatorOutput
813

9-
logfire.configure(service_name='pai-audio-evals')
10-
logfire.instrument_pydantic_ai()
11-
12-
this_dir = Path(__file__).parent
13-
assets = this_dir / 'assets'
1414

15+
@dataclass
class EditSimilarity(Evaluator[object, str, object]):
    """Score an output by token-level edit similarity to the expected output.

    Both texts are lower-cased, stripped of every character other than
    ``a-z``, ``0-9`` and whitespace, and split on whitespace. The score is
    ``1 - edit_distance / max(token counts)``, so identical token sequences
    score 1.0.
    """

    def evaluate(self, ctx: EvaluatorContext[object, str, object]) -> EvaluatorOutput:
        """Return the similarity score, or an empty mapping if nothing is expected.

        Args:
            ctx: Evaluation context; ``ctx.output`` and ``ctx.expected_output``
                are compared after normalisation.

        Returns:
            ``{}`` when ``ctx.expected_output`` is ``None`` (no metric is
            recorded), otherwise the similarity as a float.
        """
        if ctx.expected_output is None:
            return {}  # no metric
        actual_tokens = re.sub(r'[^a-z0-9\s]', '', ctx.output.lower()).split()
        expected_tokens = re.sub(r'[^a-z0-9\s]', '', ctx.expected_output.lower()).split()
        longest = max(len(actual_tokens), len(expected_tokens))
        if longest == 0:
            # Both texts normalised to no tokens (empty or punctuation-only).
            # They are identical, and dividing by zero below would raise.
            return 1.0
        distance = edit_distance(actual_tokens, expected_tokens)
        return 1 - distance / longest
1525

16-
def levenshtein_distance(s1: str, s2: str) -> int:
17-
if len(s1) < len(s2):
18-
return levenshtein_distance(s2, s1)
19-
if len(s2) == 0:
20-
return len(s1)
2126

22-
previous_row = range(len(s2) + 1)
23-
for i, c1 in enumerate(s1):
24-
current_row = [i + 1]
25-
for j, c2 in enumerate(s2):
26-
insertions = previous_row[j + 1] + 1
27-
deletions = current_row[j] + 1
28-
substitutions = previous_row[j] + (c1 != c2)
29-
current_row.append(min(insertions, deletions, substitutions))
30-
previous_row = current_row
27+
logfire.configure(service_name='pai-audio-evals', scrubbing=False, console=False)
28+
logfire.instrument_pydantic_ai()
3129

32-
return previous_row[-1]
30+
this_dir = Path(__file__).parent
31+
assets = this_dir / 'assets'
3332

3433

3534
@dataclass
class AudioFile:
    """An audio asset together with the text associated with it."""

    # File name; appended to the hosted base URL and resolved under ``assets``.
    file: str
    # Text for this clip; used as the expected output of the evaluation cases.
    text: str

    def audio_url(self) -> AudioUrl:
        """Build an ``AudioUrl`` pointing at the hosted copy of this file."""
        hosted_location = f'https://smokeshow.helpmanual.io/4l1l1s0s6q4741012x1w/{self.file}'
        return AudioUrl(hosted_location)

    def binary_content(self) -> BinaryContent:
        """Read the file from the ``assets`` directory as ``audio/mpeg`` content."""
        audio_bytes = (assets / self.file).read_bytes()
        return BinaryContent(data=audio_bytes, media_type='audio/mpeg')
@@ -45,12 +47,18 @@ def binary_content(self) -> BinaryContent:
4547
# Load the audio file records from assets.json next to this script.
files_schema = TypeAdapter(list[AudioFile])
files = files_schema.validate_json(this_dir.joinpath('assets.json').read_bytes())
# Shuffle in place so the case order differs from the order in assets.json.
random.shuffle(files)

# No model is set here; it is passed per call to ``audio_agent.run``.
audio_agent = Agent(instructions='return the transcription only, no prefix or quotes')

# One case per audio file, scored against its text by EditSimilarity.
dataset = Dataset(
    cases=[
        Case(name=file.file, inputs=file.audio_url(), expected_output=file.text)
        for file in files
    ],
    evaluators=[EditSimilarity()],
)
56+
57+
58+
async def task(audio_url: AudioUrl, model: str) -> str:
    """Run ``audio_agent`` on ``audio_url`` with the given model and return its output."""
    run_result = await audio_agent.run(['transcribe', audio_url], model=model)
    return run_result.output
60+
4961

50-
for audio_file in files[:3]:
51-
with logfire.span('Transcribing audio {audio_file.text!r}', audio_file=audio_file):
52-
model_distances: list[tuple[str, int]] = []
53-
for model in 'gpt-4o-audio-preview', 'gpt-4o-mini-audio-preview', 'google-vertex:gemini-2.0-flash':
54-
result = audio_agent.run_sync(['transcribe', audio_file.binary_content()], model=model)
55-
model_distances.append((model, levenshtein_distance(audio_file.text, result.output)))
56-
logfire.info(f'{model_distances}')
62+
# Evaluate the dataset once per model, all inside a single logfire span.
with logfire.span('Compare models'):
    for model in (
        'gpt-4o-audio-preview',
        'gpt-4o-mini-audio-preview',
        'google-vertex:gemini-2.0-flash',
    ):
        # Bind the model so the dataset can call the task with only the case input.
        dataset.evaluate_sync(partial(task, model=model), name=model, max_concurrency=10)

pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,13 @@ dependencies = [
99
"devtools>=0.12.2",
1010
"fastapi>=0.115.14",
1111
"logfire[asyncpg,fastapi,httpx]>=3.21.1",
12-
"pydantic-ai>=0.3.6",
12+
"pydantic-ai>=0.4.5",
13+
"nltk>=3.9.1",
1314
]
1415

16+
[tool.uv.sources]
17+
pydantic-ai = { git = "https://github.com/pydantic/pydantic-ai.git", rev = "0f46928bd07bc1a9f89c1d72c76cd2a86d52d489" }
18+
1519
[dependency-groups]
1620
dev = ["ruff>=0.12.2", "asyncpg-stubs>=0.30.2"]
1721

0 commit comments

Comments
 (0)