mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
add trajectory saving to eval mode.
This commit is contained in:
parent
c421582b6f
commit
7a4edb569c
4 changed files with 144 additions and 4 deletions
|
|
@ -117,6 +117,117 @@ def create_html_for_group(group_data, index):
|
|||
return group_html
|
||||
|
||||
|
||||
# --- Eval Sample Conversion ---
|
||||
|
||||
|
||||
def _eval_sample_to_viewable(sample):
|
||||
"""Convert a single eval sample dict to {messages, scores} format for the HTML viewer."""
|
||||
# Extract score
|
||||
score = sample.get("score")
|
||||
if score is None:
|
||||
is_correct = sample.get("is_correct")
|
||||
if is_correct is not None:
|
||||
score = 1.0 if is_correct else 0.0
|
||||
else:
|
||||
grade = sample.get("grade", "")
|
||||
score = 1.0 if grade == "CORRECT" else 0.0
|
||||
|
||||
# Build conversation from available fields
|
||||
if "messages" in sample and isinstance(sample["messages"], list):
|
||||
conversation = sample["messages"]
|
||||
else:
|
||||
conversation = []
|
||||
question = sample.get("question") or sample.get("problem") or ""
|
||||
if question:
|
||||
conversation.append({"role": "user", "content": str(question)})
|
||||
|
||||
response = sample.get("model_response") or sample.get("response") or ""
|
||||
if response:
|
||||
conversation.append({"role": "assistant", "content": str(response)})
|
||||
|
||||
gold = sample.get("gold_answer") or sample.get("answer") or ""
|
||||
if gold:
|
||||
conversation.append({"role": "system", "content": f"[Gold Answer]: {gold}"})
|
||||
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
return {"messages": [conversation], "scores": [score]}
|
||||
|
||||
|
||||
def generate_eval_html(input_path, output_path=None):
|
||||
"""Generate an HTML viewer from eval-format samples.jsonl.
|
||||
|
||||
Each line is a flat dict with task-specific fields (question, model_response, score, etc.).
|
||||
Converts them to the {messages, scores} format used by the existing HTML template.
|
||||
"""
|
||||
input_filepath = Path(input_path)
|
||||
if not input_filepath.is_file():
|
||||
print(f"Error: Input file not found: {input_filepath}", file=sys.stderr)
|
||||
return
|
||||
|
||||
if output_path is None:
|
||||
output_filepath = input_filepath.with_suffix(".html")
|
||||
else:
|
||||
output_filepath = Path(output_path)
|
||||
|
||||
output_filepath.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
try:
|
||||
with open(TEMPLATE_FILE, "r", encoding="utf-8") as f_template:
|
||||
html_template_content = f_template.read()
|
||||
except FileNotFoundError:
|
||||
print(f"Error: Template file not found: {TEMPLATE_FILE}", file=sys.stderr)
|
||||
return
|
||||
|
||||
all_groups_html = []
|
||||
group_index = 0
|
||||
with open(input_filepath, "r", encoding="utf-8") as f:
|
||||
for line_num, line in enumerate(f, 1):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
sample = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
print(
|
||||
f"Warning: Skipping line {line_num}. Invalid JSON.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
continue
|
||||
|
||||
viewable = _eval_sample_to_viewable(sample)
|
||||
if viewable is None:
|
||||
continue
|
||||
|
||||
group_html = create_html_for_group(viewable, group_index)
|
||||
if group_html:
|
||||
all_groups_html.append(group_html)
|
||||
group_index += 1
|
||||
|
||||
if not all_groups_html:
|
||||
print("Warning: No valid eval samples to render.", file=sys.stderr)
|
||||
groups_content = "<p>No data to display.</p>"
|
||||
else:
|
||||
groups_content = "\n".join(all_groups_html)
|
||||
|
||||
title = f"Eval Results - {input_filepath.name}"
|
||||
try:
|
||||
final_html = html_template_content.format(
|
||||
title=html.escape(title), groups_html=groups_content
|
||||
)
|
||||
except KeyError as e:
|
||||
print(
|
||||
f"Error: Template missing placeholder: {{{e}}}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return
|
||||
|
||||
with open(output_filepath, "w", encoding="utf-8") as f:
|
||||
f.write(final_html)
|
||||
print(f"Generated eval HTML viewer: {output_filepath.absolute()}")
|
||||
|
||||
|
||||
# --- Main Function ---
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue