mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-24 17:04:55 +00:00
add trajectory saving to eval mode.
This commit is contained in:
parent
c421582b6f
commit
7a4edb569c
4 changed files with 144 additions and 4 deletions
|
|
@ -18,7 +18,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
|||
import aiohttp
|
||||
import jsonlines
|
||||
import numpy as np
|
||||
import wandb
|
||||
import yaml
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic_cli import Cmd, FailedExecutionException, run_and_exit
|
||||
|
|
@ -27,6 +26,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential
|
|||
from transformers import AutoTokenizer
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
import wandb
|
||||
from atroposlib.envs.constants import ENV_NAMESPACE, NAMESPACE_SEP, OPENAI_NAMESPACE
|
||||
from atroposlib.envs.server_handling.openai_server import resolve_openai_configs
|
||||
from atroposlib.frontend.jsonl2html import generate_html
|
||||
|
|
@ -791,6 +791,13 @@ class BaseEnv(ABC):
|
|||
writer.write(sample)
|
||||
logger.info("Evaluation samples saved to %s", samples_filepath)
|
||||
|
||||
try:
|
||||
from atroposlib.frontend.jsonl2html import generate_eval_html
|
||||
|
||||
generate_eval_html(samples_filepath)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to generate eval HTML viewer: %s", e)
|
||||
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_random_exponential(multiplier=1, max=10),
|
||||
|
|
@ -1337,7 +1344,18 @@ class BaseEnv(ABC):
|
|||
Internal method to run evaluation with proper setup.
|
||||
"""
|
||||
await self.setup()
|
||||
await self.evaluate()
|
||||
try:
|
||||
await self.evaluate()
|
||||
finally:
|
||||
if self.jsonl_writer is not None:
|
||||
self.jsonl_writer.close()
|
||||
if self.config.data_path_to_save_groups:
|
||||
try:
|
||||
from atroposlib.frontend.jsonl2html import generate_html
|
||||
|
||||
generate_html(self.config.data_path_to_save_groups)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to generate trajectory HTML: %s", e)
|
||||
|
||||
@classmethod
|
||||
def cli(cls):
|
||||
|
|
@ -1928,6 +1946,10 @@ class BaseEnv(ABC):
|
|||
env_config_dict_base = default_env_config_from_init.model_dump()
|
||||
# Apply specific overrides for evaluate mode that are generally useful
|
||||
env_config_dict_base["use_wandb"] = True
|
||||
if env_config_dict_base.get("data_dir_to_save_evals") is None:
|
||||
env_config_dict_base["data_dir_to_save_evals"] = (
|
||||
f"eval_results/{cls.name or 'eval'}"
|
||||
)
|
||||
|
||||
env_config_dict = merge_dicts(
|
||||
env_config_dict_base, # `config_init` defaults with evaluate adjustments
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue