Merge branch 'main' into 2025-05-03-http-error-logging

This commit is contained in:
hjc-puro 2025-05-10 17:09:22 +08:00 committed by GitHub
commit a659217afe
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
38 changed files with 1093 additions and 475 deletions

View file

@ -20,15 +20,12 @@ import wandb
import yaml
from pydantic import BaseModel, Field
from pydantic_cli import Cmd, FailedExecutionException, run_and_exit
from rich import print as rprint
from tenacity import retry, stop_after_attempt, wait_random_exponential
from transformers import AutoTokenizer
from atroposlib.envs.constants import (
ENV_NAMESPACE,
NAMESPACE_SEP,
OPENAI_NAMESPACE,
SERVER_MANAGER_NAMESPACE,
)
from atroposlib.envs.constants import ENV_NAMESPACE, NAMESPACE_SEP, OPENAI_NAMESPACE
from atroposlib.envs.server_handling.openai_server import resolve_openai_configs
from atroposlib.frontend.jsonl2html import generate_html
from atroposlib.type_definitions import UUID
from atroposlib.utils.cli import (
@ -116,7 +113,7 @@ class BaseEnvConfig(BaseModel):
default=3, description="Maximum number of batches to have in queue."
)
tokenizer_name: str = Field(
default="NousResearch/DeepHermes-3-Llama-3-1B-Preview",
default="NousResearch/DeepHermes-3-Llama-3-3B-Preview",
description="Hugging Face tokenzer to use.",
)
use_wandb: bool = Field(default=True, description="Whether to use wandb")
@ -366,35 +363,55 @@ class BaseEnv(ABC):
)
break
@retry(
    stop=stop_after_attempt(3),
    wait=wait_random_exponential(multiplier=1, max=10),
)
async def _register_env(self):
    """
    POST this environment's registration request to the rollout server.

    The request goes to ``{rollout_server_url}/register-env`` and carries the
    configured max token length, the desired (wandb) name and the inference
    weight. The decoded JSON body of the response is returned as-is; no
    interpretation of its contents happens here.

    Any exception raised while sending the request or decoding the response
    is logged and re-raised, which lets the ``retry`` decorator make up to
    three attempts with a randomized exponential wait (capped at ``max=10``)
    between them.

    Returns:
        The JSON-decoded response body.
    """
    try:
        async with aiohttp.ClientSession() as session:
            register_url = f"{self.config.rollout_server_url}/register-env"
            payload = {
                "max_token_length": self.config.max_token_length,
                "desired_name": self.config.wandb_name,
                "weight": self.config.inference_weight,
            }
            async with session.post(register_url, json=payload) as resp:
                return await resp.json()
    except Exception as err:
        # Log here so every failed attempt is visible, then hand the error
        # back to the retry decorator.
        logger.error(f"Error registering env: {err}")
        raise err
async def register_env(self):
# Now register the env...
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.config.rollout_server_url}/register-env",
json={
"max_token_length": self.config.max_token_length,
"desired_name": self.config.wandb_name,
"weight": self.config.inference_weight,
},
) as resp:
data = await parse_http_response(resp, logger)
self.env_id = data["env_id"]
self.wandb_prepend = data["wandb_name"]
self.curr_step = data["starting_step"]
self.checkpoint_dir = data["checkpoint_dir"]
self.checkpoint_interval = data["checkpoint_interval"]
if self.config.total_steps == -1:
self.config.total_steps = data["num_steps"]
if self.config.total_steps == -1:
raise ValueError("Total steps not set in config or server!")
print(
f"Initialized env with id {self.env_id}: "
f"curr_step: {self.curr_step}, "
f"checkpoint_dir: {self.checkpoint_dir}, "
f"checkpoint_interval: {self.checkpoint_interval}"
while True:
data = await self._register_env()
if data["status"] != "success":
logging.warning(
f"Waiting to register the env due to status {data['status']}"
)
if self.curr_step > 0:
self.load_checkpoint()
await asyncio.sleep(1)
continue
self.env_id = data["env_id"]
self.wandb_prepend = data["wandb_name"]
self.curr_step = data["starting_step"]
self.checkpoint_dir = data["checkpoint_dir"]
self.checkpoint_interval = data["checkpoint_interval"]
if self.config.total_steps == -1:
self.config.total_steps = data["num_steps"]
if self.config.total_steps == -1:
raise ValueError("Total steps not set in config or server!")
print(
f"Initialized env with id {self.env_id}: "
f"curr_step: {self.curr_step}, "
f"checkpoint_dir: {self.checkpoint_dir}, "
f"checkpoint_interval: {self.checkpoint_interval}"
)
if self.curr_step > 0:
self.load_checkpoint()
break
async def get_server_info(self):
"""
@ -940,36 +957,150 @@ class BaseEnv(ABC):
type: The CliServeConfig class for serving commands.
"""
env_config, server_configs = cls.config_init()
# Get the default configurations defined by the specific environment class
default_env_config, default_server_configs = cls.config_init()
# Define namespace prefixes for CLI arguments and YAML keys
env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"
openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}"
# Define the CLI configuration class dynamically
class CliServeConfig(
get_prefixed_pydantic_model(type(env_config), env_full_prefix),
get_prefixed_pydantic_model(OpenaiConfig, openai_full_prefix),
ServerManagerConfig,
get_prefixed_pydantic_model(type(default_env_config), env_full_prefix),
get_prefixed_pydantic_model(
OpenaiConfig, openai_full_prefix
), # Use OpenaiConfig for CLI args
ServerManagerConfig, # ServerManager args are not namespaced by default
Cmd,
):
"""
Configuration for the serve command.
This combines BaseEnvConfig and OpenaiConfig into a single command.
Supports overrides via YAML config file and CLI arguments.
Order of precedence: CLI > YAML > Class Defaults.
"""
config: str | None = Field(
default=None,
description="Path to .yaml config file. CLI args override this.",
)
def run(self) -> None:
"""The logic to execute for the 'serve' command."""
# Convert this config into the formats needed by BaseEnv
# Set default wandb name if not provided and class has a name
# Note: This modifies the 'self' instance based on CLI args before full parsing.
wandb_name_attr = f"{ENV_NAMESPACE}{NAMESPACE_SEP}wandb_name"
if getattr(self, wandb_name_attr) is None and cls.name is not None:
if (
getattr(self, wandb_name_attr, None) is None
and cls.name is not None
):
setattr(self, wandb_name_attr, cls.name)
model_dumped = self.model_dump(exclude_unset=True)
server_manager_config = ServerManagerConfig(**model_dumped)
# Create the environment instance
# Load configuration from YAML file if specified
if self.config is not None:
with open(self.config, "r") as f:
yaml_config = yaml.safe_load(f)
print(f"Loaded config from {self.config}")
else:
yaml_config = {}
# Get CLI flags passed with double dashes (e.g., --env--foo bar)
cli_passed_flags = get_double_dash_flags()
# --- Configuration Merging ---
# Priority: CLI > YAML > Class Defaults
# 1. Environment Configuration
env_config_dict = merge_dicts(
default_env_config.model_dump(), # Class Defaults
yaml_config.get(ENV_NAMESPACE, {}), # YAML config
extract_namespace(cli_passed_flags, env_full_prefix), # CLI args
)
# 2. OpenAI Configuration (used for potential overrides)
oai_cli_passed_args = extract_namespace(
cli_passed_flags, openai_full_prefix
) # CLI args
yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {})
if isinstance(default_server_configs, ServerBaseline) and (
oai_cli_passed_args or yaml_oai_config
):
raise ValueError(
"ServerBaseline is not compatible with OpenAI-namespaced CLI arguments. Please edit `config_init` directly or use OpenaiConfig." # noqa: E501
)
if (
isinstance(default_server_configs, list)
and len(default_server_configs) == 1
):
# can't use the same var name because it shadows the class variable and we get an error
default_openai_config_ = default_server_configs[0]
else:
default_openai_config_ = default_server_configs
if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1:
yaml_oai_config = yaml_oai_config[0]
if isinstance(default_openai_config_, OpenaiConfig) and isinstance(
yaml_oai_config, dict
):
openai_config_dict = merge_dicts(
default_openai_config_.model_dump(), # Default OpenaiConfig (or from class init)
yaml_oai_config,
oai_cli_passed_args,
)
else:
openai_config_dict = {}
# 3. Server Manager Configuration (slurm, testing - not namespaced)
# Extract only relevant CLI flags for ServerManager
server_manager_cli_passed_flags = {}
if "slurm" in cli_passed_flags:
server_manager_cli_passed_flags["slurm"] = cli_passed_flags["slurm"]
if "testing" in cli_passed_flags:
server_manager_cli_passed_flags["testing"] = cli_passed_flags[
"testing"
]
server_manager_yaml_dict = {}
if "slurm" in yaml_config:
server_manager_yaml_dict["slurm"] = yaml_config["slurm"]
if "testing" in yaml_config:
server_manager_yaml_dict["testing"] = yaml_config["testing"]
server_manager_config_dict = merge_dicts(
ServerManagerConfig().model_dump(), # Base defaults for ServerManager
server_manager_yaml_dict, # YAML config
server_manager_cli_passed_flags, # CLI args
)
# --- Instantiate Final Config Objects ---
# Create instances from the merged dictionaries using the original default types where appropriate
# Instantiate the final environment config using its original type
env_config = type(default_env_config)(**env_config_dict)
# Instantiate the final server manager config
server_manager_config = ServerManagerConfig(
**server_manager_config_dict
)
# Determine the final server_configs, handling single, multiple servers, and overrides.
openai_configs = resolve_openai_configs(
default_server_configs=default_server_configs,
openai_config_dict=openai_config_dict,
yaml_config=yaml_config,
cli_passed_flags=cli_passed_flags,
logger=logger,
)
# --- Create and Run Environment ---
# Create the environment instance using the final, instantiated config objects
env = cls(
config=env_config,
server_configs=server_configs,
server_configs=openai_configs,
slurm=server_manager_config.slurm,
testing=server_manager_config.testing,
)
rprint(env_config)
rprint(openai_configs)
# Run the environment
asyncio.run(env.env_manager())
@ -985,11 +1116,14 @@ class BaseEnv(ABC):
type: The CliProcessConfig class for processing commands.
"""
# Define specific default configurations for the 'process' mode
PROCESS_MODE_ENV_DEFAULT_CONFIG = BaseEnvConfig(
group_size=8,
total_steps=2,
ensure_scores_are_not_same=False,
include_messages=True,
data_path_to_save_groups=f"data/{cls.name or 'groups'}.jsonl",
use_wandb=True,
)
PROCESS_MODE_OPENAI_DEFAULT_CONFIG = OpenaiConfig(
model_name="gpt-4.1-nano",
@ -1001,19 +1135,24 @@ class BaseEnv(ABC):
testing=False,
)
default_env_config, default_openai_config = cls.config_init()
if isinstance(default_openai_config, list):
default_openai_config = default_openai_config[0]
# Get the base default configurations from the specific environment class
(
default_env_config,
default_server_configs,
) = cls.config_init()
# Define namespace prefixes
env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"
openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}"
# Create Pydantic model classes with the 'process' mode defaults applied.
# These adjusted classes will be used for final instantiation.
env_config_cls_new_defaults = adjust_model_defaults(
type(default_env_config), PROCESS_MODE_ENV_DEFAULT_CONFIG
)
openai_config_cls_new_defaults = adjust_model_defaults(
OpenaiConfig, PROCESS_MODE_OPENAI_DEFAULT_CONFIG
OpenaiConfig,
PROCESS_MODE_OPENAI_DEFAULT_CONFIG,
)
server_manager_config_cls_new_defaults = adjust_model_defaults(
ServerManagerConfig,
@ -1030,6 +1169,8 @@ class BaseEnv(ABC):
):
"""
Configuration for the process command.
Supports overrides via YAML config file and CLI arguments.
Order of precedence: CLI > YAML > Process Mode Defaults > `config_init` defaults.
"""
config: str | None = Field(
@ -1039,42 +1180,72 @@ class BaseEnv(ABC):
def run(self) -> None:
"""The logic to execute for the 'process' command."""
# Setup environment configuration
# Set default wandb name if not provided and class has a name
wandb_name_attr = f"{ENV_NAMESPACE}{NAMESPACE_SEP}wandb_name"
if getattr(self, wandb_name_attr) is None and cls.name is not None:
if (
getattr(self, wandb_name_attr, None) is None
and cls.name is not None
):
setattr(self, wandb_name_attr, cls.name)
# Load configuration from YAML file if specified
if self.config is not None:
with open(self.config, "r") as f:
config = yaml.safe_load(f)
yaml_config = yaml.safe_load(f)
print(f"Loaded config from {self.config}")
else:
config = {}
yaml_config = {}
# Get CLI flags passed with double dashes
cli_passed_flags = get_double_dash_flags()
# cli args overrides config file which overrides class defaults which overrides process mode defaults
env_config = env_config_cls_new_defaults(
**merge_dicts(
default_env_config.model_dump(),
PROCESS_MODE_ENV_DEFAULT_CONFIG.model_dump(),
config.get(ENV_NAMESPACE, {}),
extract_namespace(
cli_passed_flags, env_full_prefix
), # only extract namespace for cli-passed args
)
)
openai_config = openai_config_cls_new_defaults(
**merge_dicts(
default_openai_config.model_dump(),
PROCESS_MODE_OPENAI_DEFAULT_CONFIG.model_dump(),
config.get(OPENAI_NAMESPACE, {}),
extract_namespace(
cli_passed_flags, openai_full_prefix
), # only extract namespace for cli-passed args
)
# --- Configuration Merging ---
# Priority: CLI > YAML > Process Mode Defaults > `config_init` defaults
# 1. Environment Configuration
env_config_dict = merge_dicts(
default_env_config.model_dump(), # Class Defaults
PROCESS_MODE_ENV_DEFAULT_CONFIG.model_dump(), # Process Mode Defaults
yaml_config.get(ENV_NAMESPACE, {}), # YAML config
extract_namespace(cli_passed_flags, env_full_prefix), # CLI args
)
# 2. OpenAI Configuration
oai_cli_passed_args = extract_namespace(
cli_passed_flags, openai_full_prefix
) # CLI args
yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {})
if isinstance(default_server_configs, ServerBaseline) and (
oai_cli_passed_args or yaml_oai_config
):
raise ValueError(
"ServerBaseline is not compatible with OpenAI-namespaced CLI arguments. Please edit `config_init` directly or use OpenaiConfig." # noqa: E501
)
if (
isinstance(default_server_configs, list)
and len(default_server_configs) == 1
):
# can't use the same var name because it shadows the class variable and we get an error
default_openai_config_ = default_server_configs[0]
else:
default_openai_config_ = default_server_configs
if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1:
yaml_oai_config = yaml_oai_config[0]
if isinstance(default_openai_config_, OpenaiConfig) and isinstance(
yaml_oai_config, dict
):
openai_config_dict = merge_dicts(
default_openai_config_.model_dump(), # Default OpenaiConfig (or from class init)
PROCESS_MODE_OPENAI_DEFAULT_CONFIG.model_dump(), # Process Mode Defaults
yaml_oai_config,
oai_cli_passed_args,
)
else:
openai_config_dict = {}
# 3. Server Manager Configuration
# Extract only relevant CLI flags
server_manager_cli_passed_flags = {}
if "slurm" in cli_passed_flags:
server_manager_cli_passed_flags["slurm"] = cli_passed_flags["slurm"]
@ -1083,36 +1254,68 @@ class BaseEnv(ABC):
"testing"
]
server_manager_config = server_manager_config_cls_new_defaults(
**merge_dicts(
ServerManagerConfig().model_dump(),
PROCESS_MODE_SERVER_MANAGER_DEFAULT_CONFIG.model_dump(),
config.get(SERVER_MANAGER_NAMESPACE, {}),
server_manager_cli_passed_flags,
)
server_manager_yaml_dict = {}
if "slurm" in yaml_config:
server_manager_yaml_dict["slurm"] = yaml_config["slurm"]
if "testing" in yaml_config:
server_manager_yaml_dict["testing"] = yaml_config["testing"]
server_manager_config_dict = merge_dicts(
ServerManagerConfig().model_dump(), # Base defaults
PROCESS_MODE_SERVER_MANAGER_DEFAULT_CONFIG.model_dump(), # Process Mode Defaults
server_manager_yaml_dict,
server_manager_cli_passed_flags, # CLI args
)
# --- Instantiate Final Config Objects ---
# Use the classes with adjusted defaults for instantiation
env_config = env_config_cls_new_defaults(**env_config_dict)
server_manager_config = server_manager_config_cls_new_defaults(
**server_manager_config_dict
)
# Determine the final server_configs, handling single, multiple servers, and overrides.
openai_configs = resolve_openai_configs(
default_server_configs=default_server_configs,
openai_config_dict=openai_config_dict,
yaml_config=yaml_config,
cli_passed_flags=cli_passed_flags,
logger=logger,
)
rprint(env_config)
rprint(openai_configs)
# --- Create and Run Environment ---
# Create the environment instance
env = cls(
config=env_config,
server_configs=[openai_config],
server_configs=openai_configs,
slurm=server_manager_config.slurm,
testing=server_manager_config.testing,
)
# Set the process mode parameters
# Set specific parameters for process mode on the environment instance
env.process_mode = True
env.n_groups_to_process = env_config.total_steps
env.group_size_to_process = env_config.group_size
# Validate that an output path is set (should have a default from PROCESS_MODE_ENV_DEFAULT_CONFIG)
if env_config.data_path_to_save_groups is None:
# This check might be redundant if the default is always set, but good practice.
raise ValueError(
"data_path_to_save_groups must be set for process mode"
)
print(
f"Processing {env_config.total_steps} groups of "
f"{env_config.group_size} responses and "
f"writing to {env_config.data_path_to_save_groups}"
)
# Run the environment's asynchronous process manager function
asyncio.run(env.process_manager())
# Actual implementation would go here
return CliProcessConfig