Add trajectory saving to eval mode.

This commit is contained in:
alt-glitch 2026-03-27 16:04:04 -07:00
parent c421582b6f
commit 7a4edb569c
4 changed files with 144 additions and 4 deletions

View file

@@ -117,6 +117,117 @@ def create_html_for_group(group_data, index):
return group_html
# --- Eval Sample Conversion ---
def _eval_sample_to_viewable(sample):
"""Convert a single eval sample dict to {messages, scores} format for the HTML viewer."""
# Extract score
score = sample.get("score")
if score is None:
is_correct = sample.get("is_correct")
if is_correct is not None:
score = 1.0 if is_correct else 0.0
else:
grade = sample.get("grade", "")
score = 1.0 if grade == "CORRECT" else 0.0
# Build conversation from available fields
if "messages" in sample and isinstance(sample["messages"], list):
conversation = sample["messages"]
else:
conversation = []
question = sample.get("question") or sample.get("problem") or ""
if question:
conversation.append({"role": "user", "content": str(question)})
response = sample.get("model_response") or sample.get("response") or ""
if response:
conversation.append({"role": "assistant", "content": str(response)})
gold = sample.get("gold_answer") or sample.get("answer") or ""
if gold:
conversation.append({"role": "system", "content": f"[Gold Answer]: {gold}"})
if not conversation:
return None
return {"messages": [conversation], "scores": [score]}
def generate_eval_html(input_path, output_path=None):
    """Generate an HTML viewer from eval-format samples.jsonl.

    Each input line is a flat dict with task-specific fields (question,
    model_response, score, etc.); each is converted to the {messages, scores}
    format consumed by the existing HTML template and rendered as one group.

    Args:
        input_path: Path to the samples.jsonl file.
        output_path: Destination HTML path; defaults to the input path with
            an ".html" suffix.
    """
    src = Path(input_path)
    if not src.is_file():
        print(f"Error: Input file not found: {src}", file=sys.stderr)
        return
    dst = src.with_suffix(".html") if output_path is None else Path(output_path)
    dst.parent.mkdir(parents=True, exist_ok=True)
    # Load the shared HTML template; bail out with a message if it's missing.
    try:
        with open(TEMPLATE_FILE, "r", encoding="utf-8") as f_template:
            template = f_template.read()
    except FileNotFoundError:
        print(f"Error: Template file not found: {TEMPLATE_FILE}", file=sys.stderr)
        return
    # Render each valid, convertible JSONL line as one HTML group.
    rendered = []
    with open(src, "r", encoding="utf-8") as f:
        for line_num, raw in enumerate(f, 1):
            raw = raw.strip()
            if not raw:
                continue
            try:
                sample = json.loads(raw)
            except json.JSONDecodeError:
                print(
                    f"Warning: Skipping line {line_num}. Invalid JSON.",
                    file=sys.stderr,
                )
                continue
            viewable = _eval_sample_to_viewable(sample)
            if viewable is None:
                continue
            # len(rendered) doubles as the running group index: it only
            # advances when a group actually produced HTML.
            group_html = create_html_for_group(viewable, len(rendered))
            if group_html:
                rendered.append(group_html)
    if rendered:
        groups_content = "\n".join(rendered)
    else:
        print("Warning: No valid eval samples to render.", file=sys.stderr)
        groups_content = "<p>No data to display.</p>"
    title = f"Eval Results - {src.name}"
    try:
        final_html = template.format(
            title=html.escape(title), groups_html=groups_content
        )
    except KeyError as e:
        print(
            f"Error: Template missing placeholder: {{{e}}}",
            file=sys.stderr,
        )
        return
    with open(dst, "w", encoding="utf-8") as f:
        f.write(final_html)
    print(f"Generated eval HTML viewer: {dst.absolute()}")
# --- Main Function ---