This commit is contained in:
Siddharth Balyan 2026-03-28 07:52:20 +00:00 committed by GitHub
commit f0d2fc2826
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 181 additions and 16 deletions

View file

@ -151,7 +151,7 @@
"filename": "atroposlib/envs/eval.py",
"hashed_secret": "829c3804401b0727f70f73d4415e162400cbe57b",
"is_verified": false,
"line_number": 218
"line_number": 225
}
],
"atroposlib/tests/test_reasoning_models.py": [

View file

@ -791,6 +791,30 @@ class BaseEnv(ABC):
writer.write(sample)
logger.info("Evaluation samples saved to %s", samples_filepath)
try:
from atroposlib.frontend.jsonl2html import generate_eval_html
generate_eval_html(samples_filepath)
except Exception as e:
logger.warning("Failed to generate eval HTML viewer: %s", e)
def log_eval_sample(self, sample):
"""Stream-write a single eval sample to samples.jsonl.
Lazy-initializes the writer on first call. Use this inside evaluate()
to write samples as they complete rather than batching at the end.
If using this, omit the samples= parameter from evaluate_log().
"""
if self._eval_sample_writer is None:
if self.config.data_dir_to_save_evals is None:
return
os.makedirs(self.config.data_dir_to_save_evals, exist_ok=True)
self._eval_samples_path = os.path.join(
self.config.data_dir_to_save_evals, "samples.jsonl"
)
self._eval_sample_writer = jsonlines.open(self._eval_samples_path, "w")
self._eval_sample_writer.write(sample)
@retry(
stop=stop_after_attempt(3),
wait=wait_random_exponential(multiplier=1, max=10),
@ -1336,8 +1360,32 @@ class BaseEnv(ABC):
"""
Internal method to run evaluation with proper setup.
"""
self._eval_sample_writer = None
self._eval_samples_path = None
await self.setup()
await self.evaluate()
try:
await self.evaluate()
finally:
# Close streaming eval sample writer if it was used
if self._eval_sample_writer is not None:
self._eval_sample_writer.close()
if self._eval_samples_path:
try:
from atroposlib.frontend.jsonl2html import generate_eval_html
generate_eval_html(self._eval_samples_path)
except Exception as e:
logger.warning("Failed to generate eval HTML: %s", e)
# Close JSONL trajectory writer if it was used
if self.jsonl_writer is not None:
self.jsonl_writer.close()
if self.config.data_path_to_save_groups:
try:
from atroposlib.frontend.jsonl2html import generate_html
generate_html(self.config.data_path_to_save_groups)
except Exception as e:
logger.warning("Failed to generate trajectory HTML: %s", e)
@classmethod
def cli(cls):
@ -1928,6 +1976,10 @@ class BaseEnv(ABC):
env_config_dict_base = default_env_config_from_init.model_dump()
# Apply specific overrides for evaluate mode that are generally useful
env_config_dict_base["use_wandb"] = True
if env_config_dict_base.get("data_dir_to_save_evals") is None:
env_config_dict_base["data_dir_to_save_evals"] = (
f"eval_results/{cls.name or 'eval'}"
)
env_config_dict = merge_dicts(
env_config_dict_base, # `config_init` defaults with evaluate adjustments

View file

@ -104,6 +104,13 @@ def evaluate_log(
writer.write(sample)
print(f"Evaluation samples saved to {samples_filepath}")
try:
from atroposlib.frontend.jsonl2html import generate_eval_html
generate_eval_html(samples_filepath)
except Exception as e:
print(f"Warning: Failed to generate eval HTML viewer: {e}")
class EvalBase(ABC):
""" """

View file

@ -117,6 +117,117 @@ def create_html_for_group(group_data, index):
return group_html
# --- Eval Sample Conversion ---
def _eval_sample_to_viewable(sample):
"""Convert a single eval sample dict to {messages, scores} format for the HTML viewer."""
# Extract score
score = sample.get("score")
if score is None:
is_correct = sample.get("is_correct")
if is_correct is not None:
score = 1.0 if is_correct else 0.0
else:
grade = sample.get("grade", "")
score = 1.0 if grade == "CORRECT" else 0.0
# Build conversation from available fields
if "messages" in sample and isinstance(sample["messages"], list):
conversation = sample["messages"]
else:
conversation = []
question = sample.get("question") or sample.get("problem") or ""
if question:
conversation.append({"role": "user", "content": str(question)})
response = sample.get("model_response") or sample.get("response") or ""
if response:
conversation.append({"role": "assistant", "content": str(response)})
gold = sample.get("gold_answer") or sample.get("answer") or ""
if gold:
conversation.append({"role": "system", "content": f"[Gold Answer]: {gold}"})
if not conversation:
return None
return {"messages": [conversation], "scores": [score]}
def generate_eval_html(input_path, output_path=None):
"""Generate an HTML viewer from eval-format samples.jsonl.
Each line is a flat dict with task-specific fields (question, model_response, score, etc.).
Converts them to the {messages, scores} format used by the existing HTML template.
"""
input_filepath = Path(input_path)
if not input_filepath.is_file():
print(f"Error: Input file not found: {input_filepath}", file=sys.stderr)
return
if output_path is None:
output_filepath = input_filepath.with_suffix(".html")
else:
output_filepath = Path(output_path)
output_filepath.parent.mkdir(parents=True, exist_ok=True)
try:
with open(TEMPLATE_FILE, "r", encoding="utf-8") as f_template:
html_template_content = f_template.read()
except FileNotFoundError:
print(f"Error: Template file not found: {TEMPLATE_FILE}", file=sys.stderr)
return
all_groups_html = []
group_index = 0
with open(input_filepath, "r", encoding="utf-8") as f:
for line_num, line in enumerate(f, 1):
line = line.strip()
if not line:
continue
try:
sample = json.loads(line)
except json.JSONDecodeError:
print(
f"Warning: Skipping line {line_num}. Invalid JSON.",
file=sys.stderr,
)
continue
viewable = _eval_sample_to_viewable(sample)
if viewable is None:
continue
group_html = create_html_for_group(viewable, group_index)
if group_html:
all_groups_html.append(group_html)
group_index += 1
if not all_groups_html:
print("Warning: No valid eval samples to render.", file=sys.stderr)
groups_content = "<p>No data to display.</p>"
else:
groups_content = "\n".join(all_groups_html)
title = f"Eval Results - {input_filepath.name}"
try:
final_html = html_template_content.format(
title=html.escape(title), groups_html=groups_content
)
except KeyError as e:
print(
f"Error: Template missing placeholder: {{{e}}}",
file=sys.stderr,
)
return
with open(output_filepath, "w", encoding="utf-8") as f:
f.write(final_html)
print(f"Generated eval HTML viewer: {output_filepath.absolute()}")
# --- Main Function ---

View file

@ -189,17 +189,18 @@ class GSM8kEnv(BaseEnv):
async def evaluate(self, *args, **kwargs):
start_time = time.time()
eval_tasks = []
for item in self.test:
eval_tasks.append(
self.rollout_and_score_eval(item["question"], item["gold_answer"])
async def rollout_and_log(item):
result = await self.rollout_and_score_eval(
item["question"], item["gold_answer"]
)
if result is not None:
self.log_eval_sample(result.get("sample", result))
return result
eval_tasks = [rollout_and_log(item) for item in self.test]
results = await tqdm_asyncio.gather(*eval_tasks)
# Extract scores and samples
scores = [result["score"] for result in results]
samples = [result["sample"] for result in results]
percent_correct = sum(scores) / len(scores)
end_time = time.time()
@ -207,14 +208,8 @@ class GSM8kEnv(BaseEnv):
# Add to existing metrics for wandb
self.eval_metrics.append(("eval/percent_correct", percent_correct))
# Log evaluation results
eval_metrics = {
"eval/percent_correct": percent_correct,
}
await self.evaluate_log(
metrics=eval_metrics,
samples=samples,
metrics={"eval/percent_correct": percent_correct},
start_time=start_time,
end_time=end_time,
generation_parameters={