mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
Merge 0ab46d65b0 into c20c85256e
This commit is contained in:
commit
f0d2fc2826
5 changed files with 181 additions and 16 deletions
|
|
@ -151,7 +151,7 @@
|
|||
"filename": "atroposlib/envs/eval.py",
|
||||
"hashed_secret": "829c3804401b0727f70f73d4415e162400cbe57b",
|
||||
"is_verified": false,
|
||||
"line_number": 218
|
||||
"line_number": 225
|
||||
}
|
||||
],
|
||||
"atroposlib/tests/test_reasoning_models.py": [
|
||||
|
|
|
|||
|
|
@ -791,6 +791,30 @@ class BaseEnv(ABC):
|
|||
writer.write(sample)
|
||||
logger.info("Evaluation samples saved to %s", samples_filepath)
|
||||
|
||||
try:
|
||||
from atroposlib.frontend.jsonl2html import generate_eval_html
|
||||
|
||||
generate_eval_html(samples_filepath)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to generate eval HTML viewer: %s", e)
|
||||
|
||||
def log_eval_sample(self, sample):
    """Append one evaluation sample to the streaming samples.jsonl file.

    The jsonlines writer is created on demand the first time this is
    called. Intended for use inside evaluate() so samples are persisted
    as they finish; when streaming this way, do not also pass samples=
    to evaluate_log().
    """
    if self._eval_sample_writer is None:
        eval_dir = self.config.data_dir_to_save_evals
        if eval_dir is None:
            # No output directory configured — streaming is a no-op.
            return
        os.makedirs(eval_dir, exist_ok=True)
        self._eval_samples_path = os.path.join(eval_dir, "samples.jsonl")
        self._eval_sample_writer = jsonlines.open(self._eval_samples_path, "w")
    self._eval_sample_writer.write(sample)
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_random_exponential(multiplier=1, max=10),
|
||||
|
|
@ -1336,8 +1360,32 @@ class BaseEnv(ABC):
|
|||
"""
|
||||
Internal method to run evaluation with proper setup.
|
||||
"""
|
||||
self._eval_sample_writer = None
|
||||
self._eval_samples_path = None
|
||||
await self.setup()
|
||||
await self.evaluate()
|
||||
try:
|
||||
await self.evaluate()
|
||||
finally:
|
||||
# Close streaming eval sample writer if it was used
|
||||
if self._eval_sample_writer is not None:
|
||||
self._eval_sample_writer.close()
|
||||
if self._eval_samples_path:
|
||||
try:
|
||||
from atroposlib.frontend.jsonl2html import generate_eval_html
|
||||
|
||||
generate_eval_html(self._eval_samples_path)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to generate eval HTML: %s", e)
|
||||
# Close JSONL trajectory writer if it was used
|
||||
if self.jsonl_writer is not None:
|
||||
self.jsonl_writer.close()
|
||||
if self.config.data_path_to_save_groups:
|
||||
try:
|
||||
from atroposlib.frontend.jsonl2html import generate_html
|
||||
|
||||
generate_html(self.config.data_path_to_save_groups)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to generate trajectory HTML: %s", e)
|
||||
|
||||
@classmethod
|
||||
def cli(cls):
|
||||
|
|
@ -1928,6 +1976,10 @@ class BaseEnv(ABC):
|
|||
env_config_dict_base = default_env_config_from_init.model_dump()
|
||||
# Apply specific overrides for evaluate mode that are generally useful
|
||||
env_config_dict_base["use_wandb"] = True
|
||||
if env_config_dict_base.get("data_dir_to_save_evals") is None:
|
||||
env_config_dict_base["data_dir_to_save_evals"] = (
|
||||
f"eval_results/{cls.name or 'eval'}"
|
||||
)
|
||||
|
||||
env_config_dict = merge_dicts(
|
||||
env_config_dict_base, # `config_init` defaults with evaluate adjustments
|
||||
|
|
|
|||
|
|
@ -104,6 +104,13 @@ def evaluate_log(
|
|||
writer.write(sample)
|
||||
print(f"Evaluation samples saved to {samples_filepath}")
|
||||
|
||||
try:
|
||||
from atroposlib.frontend.jsonl2html import generate_eval_html
|
||||
|
||||
generate_eval_html(samples_filepath)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to generate eval HTML viewer: {e}")
|
||||
|
||||
|
||||
class EvalBase(ABC):
|
||||
""" """
|
||||
|
|
|
|||
|
|
@ -117,6 +117,117 @@ def create_html_for_group(group_data, index):
|
|||
return group_html
|
||||
|
||||
|
||||
# --- Eval Sample Conversion ---
|
||||
|
||||
|
||||
def _eval_sample_to_viewable(sample):
|
||||
"""Convert a single eval sample dict to {messages, scores} format for the HTML viewer."""
|
||||
# Extract score
|
||||
score = sample.get("score")
|
||||
if score is None:
|
||||
is_correct = sample.get("is_correct")
|
||||
if is_correct is not None:
|
||||
score = 1.0 if is_correct else 0.0
|
||||
else:
|
||||
grade = sample.get("grade", "")
|
||||
score = 1.0 if grade == "CORRECT" else 0.0
|
||||
|
||||
# Build conversation from available fields
|
||||
if "messages" in sample and isinstance(sample["messages"], list):
|
||||
conversation = sample["messages"]
|
||||
else:
|
||||
conversation = []
|
||||
question = sample.get("question") or sample.get("problem") or ""
|
||||
if question:
|
||||
conversation.append({"role": "user", "content": str(question)})
|
||||
|
||||
response = sample.get("model_response") or sample.get("response") or ""
|
||||
if response:
|
||||
conversation.append({"role": "assistant", "content": str(response)})
|
||||
|
||||
gold = sample.get("gold_answer") or sample.get("answer") or ""
|
||||
if gold:
|
||||
conversation.append({"role": "system", "content": f"[Gold Answer]: {gold}"})
|
||||
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
return {"messages": [conversation], "scores": [score]}
|
||||
|
||||
|
||||
def generate_eval_html(input_path, output_path=None):
    """Render an eval-format samples.jsonl file into a standalone HTML viewer.

    Each input line is a flat, task-specific dict (question, model_response,
    score, ...); lines are converted to the {messages, scores} structure the
    shared HTML template expects. When output_path is omitted, the output is
    the input path with an .html suffix.
    """
    src = Path(input_path)
    if not src.is_file():
        print(f"Error: Input file not found: {src}", file=sys.stderr)
        return

    dst = src.with_suffix(".html") if output_path is None else Path(output_path)
    dst.parent.mkdir(parents=True, exist_ok=True)

    # Load the shared page template; bail out loudly if it is missing.
    try:
        with open(TEMPLATE_FILE, "r", encoding="utf-8") as tpl:
            template_text = tpl.read()
    except FileNotFoundError:
        print(f"Error: Template file not found: {TEMPLATE_FILE}", file=sys.stderr)
        return

    rendered_groups = []
    next_index = 0
    with open(src, "r", encoding="utf-8") as handle:
        for line_num, raw in enumerate(handle, 1):
            raw = raw.strip()
            if not raw:
                continue
            try:
                record = json.loads(raw)
            except json.JSONDecodeError:
                print(
                    f"Warning: Skipping line {line_num}. Invalid JSON.",
                    file=sys.stderr,
                )
                continue

            viewable = _eval_sample_to_viewable(record)
            if viewable is None:
                # Sample carried no recoverable conversation content.
                continue

            fragment = create_html_for_group(viewable, next_index)
            if fragment:
                rendered_groups.append(fragment)
                next_index += 1

    if rendered_groups:
        groups_content = "\n".join(rendered_groups)
    else:
        print("Warning: No valid eval samples to render.", file=sys.stderr)
        groups_content = "<p>No data to display.</p>"

    title = f"Eval Results - {src.name}"
    try:
        final_html = template_text.format(
            title=html.escape(title), groups_html=groups_content
        )
    except KeyError as e:
        print(
            f"Error: Template missing placeholder: {{{e}}}",
            file=sys.stderr,
        )
        return

    with open(dst, "w", encoding="utf-8") as out:
        out.write(final_html)
    print(f"Generated eval HTML viewer: {dst.absolute()}")
|
||||
|
||||
|
||||
# --- Main Function ---
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -189,17 +189,18 @@ class GSM8kEnv(BaseEnv):
|
|||
async def evaluate(self, *args, **kwargs):
|
||||
start_time = time.time()
|
||||
|
||||
eval_tasks = []
|
||||
for item in self.test:
|
||||
eval_tasks.append(
|
||||
self.rollout_and_score_eval(item["question"], item["gold_answer"])
|
||||
async def rollout_and_log(item):
|
||||
result = await self.rollout_and_score_eval(
|
||||
item["question"], item["gold_answer"]
|
||||
)
|
||||
if result is not None:
|
||||
self.log_eval_sample(result.get("sample", result))
|
||||
return result
|
||||
|
||||
eval_tasks = [rollout_and_log(item) for item in self.test]
|
||||
results = await tqdm_asyncio.gather(*eval_tasks)
|
||||
|
||||
# Extract scores and samples
|
||||
scores = [result["score"] for result in results]
|
||||
samples = [result["sample"] for result in results]
|
||||
|
||||
percent_correct = sum(scores) / len(scores)
|
||||
|
||||
end_time = time.time()
|
||||
|
|
@ -207,14 +208,8 @@ class GSM8kEnv(BaseEnv):
|
|||
# Add to existing metrics for wandb
|
||||
self.eval_metrics.append(("eval/percent_correct", percent_correct))
|
||||
|
||||
# Log evaluation results
|
||||
eval_metrics = {
|
||||
"eval/percent_correct": percent_correct,
|
||||
}
|
||||
|
||||
await self.evaluate_log(
|
||||
metrics=eval_metrics,
|
||||
samples=samples,
|
||||
metrics={"eval/percent_correct": percent_correct},
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
generation_parameters={
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue