547 lines
20 KiB
Python
Executable File
547 lines
20 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import csv
|
|
from difflib import SequenceMatcher
|
|
import json
|
|
import re
|
|
import statistics
|
|
import unicodedata
|
|
from pathlib import Path
|
|
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
|
|
|
|
# Every timing line in the log has the shape "<label> time = <number> ms".
# The number (integer or decimal) is captured as group 1.
_MS_NUMBER = r"([0-9]+(?:\.[0-9]+)?)"

# (result key, label that precedes the word "time" in the log line)
_TIMING_LABELS = (
    ("model_load_ms", "load"),
    ("mel_ms", "mel"),
    ("sample_ms", "sample"),
    ("encode_ms", "encode"),
    ("decode_ms", "decode"),
    ("batchd_ms", "batchd"),
    ("prompt_ms", "prompt"),
    ("full_runtime_ms", "total"),
)

# Maps each metric key to the compiled regex that extracts its millisecond value.
TIMING_PATTERNS = {
    metric_key: re.compile(label + r" time\s*=\s*" + _MS_NUMBER + r"\s*ms")
    for metric_key, label in _TIMING_LABELS
}
|
|
|
|
# Case-insensitive spellings of a token-rate line, tried in order; group 1 is
# the numeric rate. The first form covers "token(s)/s", "/sec" and "/second",
# the second covers "token(s) per second". Both need a ":" or "=" before the
# number.
_TOKEN_RATE_SOURCES = (
    r"tokens?\s*/\s*s(?:ec(?:ond)?)?\s*[:=]\s*([0-9]+(?:\.[0-9]+)?)",
    r"tokens?\s+per\s+second\s*[:=]\s*([0-9]+(?:\.[0-9]+)?)",
)

TOKEN_RATE_PATTERNS = [
    re.compile(source, re.IGNORECASE) for source in _TOKEN_RATE_SOURCES
]
|
|
|
|
# A transcript segment line looks like "[HH:MM:SS.mmm --> HH:MM:SS.mmm] text".
# Group 1 is whatever follows the closing bracket.
_SEGMENT_TIMESTAMP = r"\d{2}:\d{2}:\d{2}\.\d{3}"

SEGMENT_LINE_PATTERN = re.compile(
    r"^\[" + _SEGMENT_TIMESTAMP + r"\s+-->\s+" + _SEGMENT_TIMESTAMP + r"\]\s*(.*)$"
)
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the benchmark log parser.

    Returns:
        A namespace with ``run_dir`` (required), ``refs_dir`` (``None`` when
        omitted), ``max_wer`` and ``max_cer`` (floats, default 0.02) and the
        boolean flag ``enforce_correctness``.
    """
    parser = argparse.ArgumentParser(
        description="Parse benchmark logs and compute run statistics."
    )

    # Locations.
    parser.add_argument(
        "--run-dir",
        required=True,
        help="Path to benchmark results run directory.",
    )
    parser.add_argument(
        "--refs-dir",
        default=None,
        help="Directory containing reference transcripts {short,medium,long}.txt (default: benchmark/references).",
    )

    # Correctness thresholds and gate.
    parser.add_argument(
        "--max-wer",
        type=float,
        default=0.02,
        help="Max allowed median WER per audio.",
    )
    parser.add_argument(
        "--max-cer",
        type=float,
        default=0.02,
        help="Max allowed median CER per audio.",
    )
    parser.add_argument(
        "--enforce-correctness",
        action="store_true",
        help="Exit non-zero if any referenced audio exceeds max WER/CER or if references are missing.",
    )

    namespace = parser.parse_args()
    return namespace
|
|
|
|
|
|
def to_float_or_none(value: Optional[float]) -> Optional[float]:
    """Coerce `value` to ``float``, passing ``None`` through unchanged."""
    return None if value is None else float(value)
|
|
|
|
|
|
def parse_log_metrics(log_text: str) -> Dict[str, Optional[float]]:
    """Extract timing metrics and the token rate from a benchmark log.

    Args:
        log_text: Full text of one run's log.

    Returns:
        A dict with one entry per key of ``TIMING_PATTERNS`` plus
        ``"tokens_per_second"``. Each value is the float captured by the first
        match of the corresponding pattern, or ``None`` when nothing matched.
    """
    metrics: Dict[str, Optional[float]] = {}
    for name, pattern in TIMING_PATTERNS.items():
        found = pattern.search(log_text)
        metrics[name] = None if found is None else float(found.group(1))

    # The first token-rate pattern (in list order) that matches wins.
    rate_hits = (pattern.search(log_text) for pattern in TOKEN_RATE_PATTERNS)
    metrics["tokens_per_second"] = next(
        (float(hit.group(1)) for hit in rate_hits if hit),
        None,
    )
    return metrics
|
|
|
|
|
|
def extract_transcript(log_text: str) -> str:
    """Join the text of every timestamped segment line in `log_text`.

    Each line is stripped and matched against ``SEGMENT_LINE_PATTERN``.
    Lines that do not match, and segments whose text is empty after
    stripping, are skipped. The remaining texts are joined with single
    spaces.
    """
    segments: List[str] = []
    for candidate in map(str.strip, log_text.splitlines()):
        hit = SEGMENT_LINE_PATTERN.match(candidate)
        if hit is None:
            continue
        spoken = hit.group(1).strip()
        if not spoken:
            continue
        segments.append(spoken)

    joined = " ".join(segments)
    return joined.strip()
|
|
|
|
|
|
def normalize_text(text: str) -> str:
    """Canonicalise `text` before WER/CER comparison.

    Steps, in order: NFKC-normalise and lower-case, turn underscores into
    spaces, replace every character that is neither a word character,
    whitespace nor an apostrophe with a space, then collapse runs of
    whitespace and trim the ends.
    """
    folded = unicodedata.normalize("NFKC", text).lower()
    # "_" counts as a word character for \w, so it has to go before the
    # punctuation pass or it would survive it.
    without_underscores = folded.replace("_", " ")
    without_punctuation = re.sub(
        r"[^\w\s']", " ", without_underscores, flags=re.UNICODE
    )
    collapsed = re.sub(r"\s+", " ", without_punctuation)
    return collapsed.strip()
|
|
|
|
|
|
def levenshtein_distance(a: Sequence[str], b: Sequence[str]) -> int:
    """Return the edit distance between sequences `a` and `b`.

    Insertions, deletions and substitutions each cost 1. Only one row of the
    dynamic-programming table is kept and updated in place.
    """
    if not (a and b):
        # Against an empty sequence the distance is the other one's length.
        return len(b) if not a else len(a)

    width = len(b)
    row = list(range(width + 1))
    for row_no, left in enumerate(a, start=1):
        # `diagonal` holds the previous row's value one column to the left.
        diagonal = row[0]
        row[0] = row_no
        for col in range(1, width + 1):
            above = row[col]
            step = 0 if left == b[col - 1] else 1
            row[col] = min(
                above + 1,  # delete
                row[col - 1] + 1,  # insert
                diagonal + step,  # substitute
            )
            diagonal = above
    return row[width]
|
|
|
|
|
|
def error_rate(reference: Sequence[str], hypothesis: Sequence[str]) -> float:
    """Return the edit-distance error rate of `hypothesis` against `reference`.

    The rate is ``edits / len(reference)``. As with standard WER/CER it can
    exceed 1.0 when the hypothesis is much longer than the reference.

    Args:
        reference: Ground-truth tokens (words or characters).
        hypothesis: Tokens produced by the system under test.

    Returns:
        0.0 when both sequences are empty, 1.0 when only the reference is
        empty, otherwise the exact or approximated edit distance divided by
        the reference length.
    """
    ref_len = len(reference)
    hyp_len = len(hypothesis)
    if ref_len == 0:
        return 0.0 if hyp_len == 0 else 1.0

    # Exact DP is expensive for long transcripts. Use the exact distance for
    # small/medium inputs and a fast matching-based estimate for long inputs.
    if ref_len * max(1, hyp_len) <= 2_000_000:
        return levenshtein_distance(reference, hypothesis) / float(ref_len)

    matcher = SequenceMatcher(a=list(reference), b=list(hypothesis), autojunk=False)
    matched = sum(block.size for block in matcher.get_matching_blocks())

    # Every unmatched position of the longer sequence needs at least one edit
    # (a substitution where the other side also has an unmatched token, an
    # insertion or deletion otherwise). Normalising by the reference length
    # keeps this estimate on the same scale as the exact path above;
    # `1 - ratio()` divides by len(reference) + len(hypothesis) instead and
    # under-reports pure insertions/deletions by up to a factor of two.
    approx_edits = max(ref_len, hyp_len) - matched
    return approx_edits / float(ref_len)
|
|
|
|
|
|
def compute_wer_cer(reference_text: str, hypothesis_text: str) -> Tuple[float, float]:
    """Compute word and character error rates of a hypothesis transcript.

    Both texts go through ``normalize_text`` first. Words are the
    whitespace-separated tokens of the normalised text; characters are the
    normalised text with spaces removed.

    Returns:
        ``(wer, cer)`` as computed by ``error_rate``.
    """
    normalised_ref = normalize_text(reference_text)
    normalised_hyp = normalize_text(hypothesis_text)

    # Word level: splitting an empty string already yields an empty list.
    wer = error_rate(normalised_ref.split(), normalised_hyp.split())

    # Character level: spaces are not counted as characters.
    cer = error_rate(
        list(normalised_ref.replace(" ", "")),
        list(normalised_hyp.replace(" ", "")),
    )
    return wer, cer
|
|
|
|
|
|
def safe_mean(values: Iterable[float]) -> Optional[float]:
    """Return the arithmetic mean of `values`, or ``None`` if there are none."""
    samples = list(values)
    if not samples:
        return None
    return statistics.mean(samples)
|
|
|
|
|
|
def safe_median(values: Iterable[float]) -> Optional[float]:
    """Return the median of `values`, or ``None`` if there are none."""
    samples = list(values)
    if not samples:
        return None
    return statistics.median(samples)
|
|
|
|
|
|
def safe_min(values: Iterable[float]) -> Optional[float]:
    """Return the smallest of `values`, or ``None`` if there are none."""
    samples = list(values)
    if not samples:
        return None
    return min(samples)
|
|
|
|
|
|
def safe_max(values: Iterable[float]) -> Optional[float]:
    """Return the largest of `values`, or ``None`` if there are none."""
    samples = list(values)
    if not samples:
        return None
    return max(samples)
|
|
|
|
|
|
def safe_stdev(values: Iterable[float]) -> Optional[float]:
    """Return the sample standard deviation of `values`.

    Returns ``None`` for no values and 0.0 for a single value
    (``statistics.stdev`` needs at least two data points).
    """
    samples = list(values)
    count = len(samples)
    if count >= 2:
        return statistics.stdev(samples)
    return 0.0 if count == 1 else None
|
|
|
|
|
|
def stats_block(values: Iterable[float]) -> Dict[str, Optional[float]]:
    """Summarise `values` as mean, median, min, max and standard deviation.

    Returns:
        A dict with keys ``"mean"``, ``"median"``, ``"min"``, ``"max"`` and
        ``"std_dev"``, each holding the result of the matching ``safe_*``
        helper.
    """
    # Materialise once so a one-shot iterator can feed every statistic.
    samples = list(values)
    reducers = (
        ("mean", safe_mean),
        ("median", safe_median),
        ("min", safe_min),
        ("max", safe_max),
        ("std_dev", safe_stdev),
    )
    return {label: reduce_fn(samples) for label, reduce_fn in reducers}
|
|
|
|
|
|
def fmt(value: Optional[float], suffix: str = "", decimals: int = 3) -> str:
    """Render `value` in fixed-point notation followed by `suffix`.

    Args:
        value: Number to render; ``None`` is rendered as ``"NA"`` with no suffix.
        suffix: Text appended after the number.
        decimals: Number of digits after the decimal point.
    """
    if value is not None:
        return format(value, f".{decimals}f") + suffix
    return "NA"
|
|
|
|
|
|
def write_runs_csv(path: Path, rows: List[Dict[str, object]]) -> None:
    """Write the per-run rows to `path` as a CSV file with a fixed column set.

    Args:
        path: Destination file; it is overwritten.
        rows: One dict per run, keyed by the column names listed below.
            When empty, an empty file (without a header) is written.
    """
    if len(rows) == 0:
        path.write_text("", encoding="utf-8")
        return

    # Column order of runs.csv.
    columns = [
        "variant",
        "model",
        "audio_key",
        "audio_length_s",
        "run_kind",
        "run_index",
        "wall_clock_runtime_s",
        "tokens_per_second",
        "audio_seconds_per_second",
        "model_load_ms",
        "first_inference_latency_s",
        "full_runtime_ms",
        "mel_ms",
        "sample_ms",
        "encode_ms",
        "decode_ms",
        "batchd_ms",
        "prompt_ms",
        "transcript_word_error_rate",
        "transcript_char_error_rate",
        "reference_present",
        "reference_path",
        "transcript_path",
        "metal_kernel_runtime_ms",
        "cpu_orchestration_ms",
        "log_path",
    ]

    # newline="" hands line-ending control to the csv module.
    with path.open("w", newline="", encoding="utf-8") as handle:
        sink = csv.DictWriter(handle, fieldnames=columns)
        sink.writeheader()
        for record in rows:
            sink.writerow(record)
|
|
|
|
|
|
# Audio keys whose references are always looked up and which are listed first
# in the summary outputs.
_KNOWN_AUDIO_KEYS = ("short", "medium", "long")


def _numeric_column(rows: List[Dict[str, object]], name: str) -> List[float]:
    """Return the non-None values stored under `name` in `rows`, as floats."""
    return [float(row[name]) for row in rows if row[name] is not None]


def main() -> int:
    """Parse one benchmark run directory and write its report files.

    Reads ``config.json`` and every ``raw/*/*run_*.meta.json`` under the run
    directory, parses the log referenced by each measured run, and writes:

    * ``<log>.transcript.txt`` next to every parsed log,
    * ``runs.csv`` (one row per measured run),
    * ``summary.csv``, ``summary.json`` and ``summary.md`` (one entry per audio),
    * ``correctness.json`` (WER/CER gate results).

    Returns:
        0 on success. With ``--enforce-correctness``: 3 when any reference
        transcript is missing, 4 when an audio failed the WER/CER thresholds.

    Raises:
        SystemExit: If the run directory or ``config.json`` is missing, no
            run metadata files are found, or none of them is a measured run.
    """
    args = parse_args()
    run_dir = Path(args.run_dir).resolve()
    if not run_dir.is_dir():
        raise SystemExit(f"Run directory does not exist: {run_dir}")

    if args.refs_dir:
        refs_dir = Path(args.refs_dir).resolve()
    else:
        refs_dir = run_dir.parent.parent / "references"

    config_path = run_dir / "config.json"
    if not config_path.is_file():
        raise SystemExit(f"Missing config file: {config_path}")

    config = json.loads(config_path.read_text(encoding="utf-8"))
    variant = config["variant"]
    model = Path(config["model"]["rel_path"]).name

    run_meta_paths = sorted((run_dir / "raw").glob("*/*run_*.meta.json"))
    if not run_meta_paths:
        raise SystemExit(f"No measured run metadata files found under {(run_dir / 'raw')}")

    reference_texts: Dict[str, Optional[str]] = {}
    reference_paths: Dict[str, Path] = {}
    missing_references: List[str] = []

    def load_reference(key: str) -> None:
        """Record the reference path and text for `key`, or note it as missing."""
        ref_path = refs_dir / f"{key}.txt"
        reference_paths[key] = ref_path
        if ref_path.is_file():
            reference_texts[key] = ref_path.read_text(encoding="utf-8", errors="replace").strip()
        else:
            reference_texts[key] = None
            missing_references.append(key)

    for known_key in _KNOWN_AUDIO_KEYS:
        load_reference(known_key)

    # ---- Per-run rows -----------------------------------------------------
    run_rows: List[Dict[str, object]] = []
    for meta_path in run_meta_paths:
        meta = json.loads(meta_path.read_text(encoding="utf-8"))
        if meta.get("run_kind") != "measured":
            continue

        log_path = Path(meta["log_path"])
        log_text = log_path.read_text(encoding="utf-8", errors="replace")
        parsed = parse_log_metrics(log_text)
        transcript_text = extract_transcript(log_text)
        transcript_path = log_path.with_suffix(".transcript.txt")
        transcript_path.write_text(transcript_text + "\n", encoding="utf-8")

        audio_key = str(meta["audio_key"])
        # A run may carry an audio key outside the known set; look its
        # reference up on demand instead of failing with KeyError below.
        if audio_key not in reference_texts:
            load_reference(audio_key)
        ref_text = reference_texts[audio_key]
        ref_present = ref_text is not None
        wer: Optional[float] = None
        cer: Optional[float] = None
        if ref_text is not None:
            wer, cer = compute_wer_cer(ref_text, transcript_text)

        wall_s = float(meta["wall_clock_runtime_s"])
        audio_len_s = float(meta["audio_duration_s"])
        audio_s_per_s = audio_len_s / wall_s if wall_s > 0 else None
        full_runtime_ms = parsed["full_runtime_ms"]
        if full_runtime_ms is None:
            # The log had no "total time" line; fall back to the wall clock.
            full_runtime_ms = wall_s * 1000.0

        run_rows.append(
            {
                "variant": variant,
                "model": model,
                "audio_key": audio_key,
                "audio_length_s": audio_len_s,
                "run_kind": meta["run_kind"],
                "run_index": int(meta["run_index"]),
                "wall_clock_runtime_s": wall_s,
                "tokens_per_second": to_float_or_none(parsed["tokens_per_second"]),
                "audio_seconds_per_second": to_float_or_none(audio_s_per_s),
                "model_load_ms": to_float_or_none(parsed["model_load_ms"]),
                "first_inference_latency_s": to_float_or_none(meta.get("first_inference_latency_s")),
                "full_runtime_ms": to_float_or_none(full_runtime_ms),
                "mel_ms": to_float_or_none(parsed["mel_ms"]),
                "sample_ms": to_float_or_none(parsed["sample_ms"]),
                "encode_ms": to_float_or_none(parsed["encode_ms"]),
                "decode_ms": to_float_or_none(parsed["decode_ms"]),
                "batchd_ms": to_float_or_none(parsed["batchd_ms"]),
                "prompt_ms": to_float_or_none(parsed["prompt_ms"]),
                "transcript_word_error_rate": to_float_or_none(wer),
                "transcript_char_error_rate": to_float_or_none(cer),
                "reference_present": ref_present,
                "reference_path": str(reference_paths[audio_key]),
                "transcript_path": str(transcript_path),
                "metal_kernel_runtime_ms": None,
                "cpu_orchestration_ms": None,
                "log_path": str(log_path),
            }
        )

    if not run_rows:
        raise SystemExit("No measured runs were parsed.")

    runs_csv_path = run_dir / "runs.csv"
    write_runs_csv(runs_csv_path, run_rows)

    # ---- Per-audio summary ------------------------------------------------
    by_audio: Dict[str, List[Dict[str, object]]] = {}
    for row in run_rows:
        by_audio.setdefault(str(row["audio_key"]), []).append(row)

    summary_rows: List[Dict[str, object]] = []
    md_lines = [
        "| Variant | Model | Audio Length | Runs | Init Mean | First Inference Mean | Runtime Median | Throughput | Std Dev | Notes |",
        "|---|---|---:|---:|---:|---:|---:|---:|---:|---|",
    ]
    overall_correctness_pass = True

    # Known keys keep their fixed order; any other key follows alphabetically.
    extra_keys = sorted(key for key in by_audio if key not in _KNOWN_AUDIO_KEYS)
    for audio_key in (*_KNOWN_AUDIO_KEYS, *extra_keys):
        rows = sorted(by_audio.get(audio_key, []), key=lambda r: int(r["run_index"]))
        if not rows:
            continue

        audio_length_s = float(rows[0]["audio_length_s"])
        reference_present = bool(rows[0]["reference_present"])

        runtime_stats = stats_block(_numeric_column(rows, "wall_clock_runtime_s"))
        full_runtime_stats = stats_block(_numeric_column(rows, "full_runtime_ms"))
        load_stats = stats_block(_numeric_column(rows, "model_load_ms"))
        first_stats = stats_block(_numeric_column(rows, "first_inference_latency_s"))
        throughput_stats = stats_block(_numeric_column(rows, "audio_seconds_per_second"))
        token_stats = stats_block(_numeric_column(rows, "tokens_per_second"))
        encode_stats = stats_block(_numeric_column(rows, "encode_ms"))
        decode_stats = stats_block(_numeric_column(rows, "decode_ms"))
        wer_stats = stats_block(_numeric_column(rows, "transcript_word_error_rate"))
        cer_stats = stats_block(_numeric_column(rows, "transcript_char_error_rate"))

        notes_parts: List[str] = []
        if token_stats["mean"] is None:
            notes_parts.append("tokens/s unavailable")
        else:
            notes_parts.append(f"tokens/s mean={token_stats['mean']:.3f}")
        if encode_stats["mean"] is not None:
            notes_parts.append(f"encode mean={encode_stats['mean']:.2f} ms")
        if decode_stats["mean"] is not None:
            notes_parts.append(f"decode mean={decode_stats['mean']:.2f} ms")
        if reference_present:
            notes_parts.append(f"wer median={fmt(wer_stats['median'], '', 4)}")
            notes_parts.append(f"cer median={fmt(cer_stats['median'], '', 4)}")
        else:
            notes_parts.append("reference missing")

        # Without a reference the verdict is "unknown" (None) unless the gate
        # is enforced, in which case a missing reference counts as a failure.
        correctness_pass: Optional[bool]
        if not reference_present:
            correctness_pass = False if args.enforce_correctness else None
        else:
            correctness_pass = (
                wer_stats["median"] is not None
                and cer_stats["median"] is not None
                and wer_stats["median"] <= args.max_wer
                and cer_stats["median"] <= args.max_cer
            )
        if correctness_pass is False:
            overall_correctness_pass = False

        notes = "; ".join(notes_parts)

        summary_rows.append(
            {
                "variant": variant,
                "model": model,
                "audio_key": audio_key,
                "audio_length_s": audio_length_s,
                "runs": len(rows),
                "model_load_ms": load_stats,
                "first_inference_latency_s": first_stats,
                "wall_clock_runtime_s": runtime_stats,
                "full_runtime_ms": full_runtime_stats,
                "throughput_audio_seconds_per_second": throughput_stats,
                "tokens_per_second": token_stats,
                "encode_ms": encode_stats,
                "decode_ms": decode_stats,
                "wer": wer_stats,
                "cer": cer_stats,
                "reference_present": reference_present,
                "correctness_pass": correctness_pass,
                "notes": notes,
            }
        )

        md_cells = [
            str(variant),
            str(model),
            f"{audio_length_s:.3f}s",
            str(len(rows)),
            fmt(load_stats["mean"], " ms", decimals=2),
            fmt(first_stats["mean"], " s", decimals=3),
            fmt(runtime_stats["median"], " s", decimals=3),
            fmt(throughput_stats["mean"], " audio-s/s", decimals=3),
            fmt(runtime_stats["std_dev"], " s", decimals=3),
            notes,
        ]
        md_lines.append("| " + " | ".join(md_cells) + " |")

    # ---- Report files -----------------------------------------------------
    summary_csv_path = run_dir / "summary.csv"
    with summary_csv_path.open("w", newline="", encoding="utf-8") as f:
        fieldnames = [
            "variant",
            "model",
            "audio_key",
            "audio_length_s",
            "runs",
            "init_mean_ms",
            "first_inference_mean_s",
            "runtime_median_s",
            "throughput_mean_audio_s_per_s",
            "runtime_std_dev_s",
            "reference_present",
            "wer_median",
            "cer_median",
            "max_wer",
            "max_cer",
            "correctness_pass",
            "notes",
        ]
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        for row in summary_rows:
            writer.writerow(
                {
                    "variant": row["variant"],
                    "model": row["model"],
                    "audio_key": row["audio_key"],
                    "audio_length_s": row["audio_length_s"],
                    "runs": row["runs"],
                    "init_mean_ms": row["model_load_ms"]["mean"],
                    "first_inference_mean_s": row["first_inference_latency_s"]["mean"],
                    "runtime_median_s": row["wall_clock_runtime_s"]["median"],
                    "throughput_mean_audio_s_per_s": row["throughput_audio_seconds_per_second"]["mean"],
                    "runtime_std_dev_s": row["wall_clock_runtime_s"]["std_dev"],
                    "reference_present": row["reference_present"],
                    "wer_median": row["wer"]["median"],
                    "cer_median": row["cer"]["median"],
                    "max_wer": args.max_wer,
                    "max_cer": args.max_cer,
                    "correctness_pass": row["correctness_pass"],
                    "notes": row["notes"],
                }
            )

    summary_json_path = run_dir / "summary.json"
    summary_json_path.write_text(json.dumps(summary_rows, indent=2) + "\n", encoding="utf-8")

    correctness_json_path = run_dir / "correctness.json"
    correctness_report = {
        "refs_dir": str(refs_dir),
        "max_wer": args.max_wer,
        "max_cer": args.max_cer,
        "enforce_correctness": args.enforce_correctness,
        "missing_references": missing_references,
        "overall_correctness_pass": overall_correctness_pass and (
            not args.enforce_correctness or len(missing_references) == 0
        ),
        "audios": [
            {
                "audio_key": row["audio_key"],
                "reference_present": row["reference_present"],
                "wer_median": row["wer"]["median"],
                "cer_median": row["cer"]["median"],
                "correctness_pass": row["correctness_pass"],
            }
            for row in summary_rows
        ],
    }
    correctness_json_path.write_text(json.dumps(correctness_report, indent=2) + "\n", encoding="utf-8")

    summary_md_path = run_dir / "summary.md"
    summary_md_path.write_text("\n".join(md_lines) + "\n", encoding="utf-8")

    print(f"Wrote: {runs_csv_path}")
    print(f"Wrote: {summary_csv_path}")
    print(f"Wrote: {summary_json_path}")
    print(f"Wrote: {summary_md_path}")
    print(f"Wrote: {correctness_json_path}")

    # ---- Correctness gate -------------------------------------------------
    if args.enforce_correctness:
        if missing_references:
            print(
                "Correctness gate failed: missing references for "
                + ", ".join(sorted(missing_references))
            )
            return 3
        if not overall_correctness_pass:
            print(
                f"Correctness gate failed: one or more audios exceeded max WER={args.max_wer} or max CER={args.max_cer}"
            )
            return 4

    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Hand main()'s return value to the interpreter as the exit status.
    exit_status = main()
    raise SystemExit(exit_status)
|