Add trajectory saving to eval mode.

This commit is contained in:
alt-glitch 2026-03-27 16:04:04 -07:00
parent c421582b6f
commit 7a4edb569c
4 changed files with 144 additions and 4 deletions

View file

@ -18,7 +18,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import aiohttp
import jsonlines
import numpy as np
import wandb
import yaml
from pydantic import BaseModel, Field
from pydantic_cli import Cmd, FailedExecutionException, run_and_exit
@ -27,6 +26,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential
from transformers import AutoTokenizer
from typing_extensions import TypedDict
import wandb
from atroposlib.envs.constants import ENV_NAMESPACE, NAMESPACE_SEP, OPENAI_NAMESPACE
from atroposlib.envs.server_handling.openai_server import resolve_openai_configs
from atroposlib.frontend.jsonl2html import generate_html
@ -791,6 +791,13 @@ class BaseEnv(ABC):
writer.write(sample)
logger.info("Evaluation samples saved to %s", samples_filepath)
try:
from atroposlib.frontend.jsonl2html import generate_eval_html
generate_eval_html(samples_filepath)
except Exception as e:
logger.warning("Failed to generate eval HTML viewer: %s", e)
@retry(
stop=stop_after_attempt(3),
wait=wait_random_exponential(multiplier=1, max=10),
@ -1337,7 +1344,18 @@ class BaseEnv(ABC):
Internal method to run evaluation with proper setup.
"""
await self.setup()
await self.evaluate()
try:
await self.evaluate()
finally:
if self.jsonl_writer is not None:
self.jsonl_writer.close()
if self.config.data_path_to_save_groups:
try:
from atroposlib.frontend.jsonl2html import generate_html
generate_html(self.config.data_path_to_save_groups)
except Exception as e:
logger.warning("Failed to generate trajectory HTML: %s", e)
@classmethod
def cli(cls):
@ -1928,6 +1946,10 @@ class BaseEnv(ABC):
env_config_dict_base = default_env_config_from_init.model_dump()
# Apply specific overrides for evaluate mode that are generally useful
env_config_dict_base["use_wandb"] = True
if env_config_dict_base.get("data_dir_to_save_evals") is None:
env_config_dict_base["data_dir_to_save_evals"] = (
f"eval_results/{cls.name or 'eval'}"
)
env_config_dict = merge_dicts(
env_config_dict_base, # `config_init` defaults with evaluate adjustments