add process subcommand

This commit is contained in:
hjc-puro 2025-05-02 03:42:10 -04:00
parent 0f966ec3fb
commit 78cfef9daf
6 changed files with 826 additions and 29 deletions

View file

@ -1,26 +1,42 @@
import asyncio
import json
import logging
import string
import os
import random
import sys
import string
import time
import uuid
import warnings
from abc import ABC, abstractmethod
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union
import aiohttp
import jsonlines
import numpy as np
import wandb
import yaml
from pydantic import BaseModel, Field
from pydantic_cli import Cmd, FailedExecutionException, run_and_exit
from tenacity import retry, stop_after_attempt, wait_random_exponential
from trajectoryhandler.envs.constants import (
ENV_NAMESPACE,
NAMESPACE_SEP,
OPENAI_NAMESPACE,
SERVER_MANAGER_NAMESPACE,
)
from trajectoryhandler.frontend.jsonl2html import generate_html
from trajectoryhandler.utils.cli import (
adjust_model_defaults,
extract_namespace,
get_double_dash_flags,
get_prefixed_pydantic_model,
merge_dicts,
)
from transformers import AutoTokenizer
import wandb
from atroposlib.type_definitions import UUID
from atroposlib.utils.metrics import get_std_min_max_avg
@ -130,12 +146,15 @@ class BaseEnvConfig(BaseModel):
default=2,
description="Minimum number of items sent before logging, if 0 or less, logs every time",
)
include_messages: bool = Field(
default=False,
description="Whether to include messages in the output transmitted to the trainer",
)
class BaseEnv(ABC):
name = None
env_config_cls = BaseEnvConfig
def __init__(
self,
@ -176,10 +195,22 @@ class BaseEnv(ABC):
self.checkpoint_dir = ""
self.checkpoint_interval = -1
if self.config.data_path_to_save_groups is not None:
if os.path.exists(self.config.data_path_to_save_groups):
raise FileExistsError(
"Data path already exists! Please remove it or change it."
# Find a suitable filename by appending _1, _2, etc. if the file already exists
original_path = self.config.data_path_to_save_groups
counter = 1
path_changed = False
while os.path.exists(self.config.data_path_to_save_groups):
path_obj = Path(original_path)
self.config.data_path_to_save_groups = str(
path_obj.with_stem(f"{path_obj.stem}_{counter}")
)
counter += 1
path_changed = True
if path_changed:
print(
f"Changed data path to {self.config.data_path_to_save_groups} because {original_path} already exists." # noqa: E501
)
self.jsonl_writer = jsonlines.open(
self.config.data_path_to_save_groups, "w"
) # type: jsonlines.Writer
@ -193,7 +224,7 @@ class BaseEnv(ABC):
"""
Initialize the config
"""
return cls.env_config_cls(), ServerBaseline()
return BaseEnvConfig(), ServerBaseline()
async def collect_trajectory(self, item: Item) -> Tuple[Any | None, List[Item]]:
raise NotImplementedError(
@ -527,6 +558,8 @@ class BaseEnv(ABC):
self,
scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
item: Item = None,
do_send_to_api: bool = True,
abort_on_any_max_length_exceeded: bool = True,
):
"""
Send the chats to the API with robust error handling and support for multiple ScoredDataGroups.
@ -546,6 +579,7 @@ class BaseEnv(ABC):
if self.config.ensure_scores_are_not_same:
if len(set(scored_data["scores"])) == 1:
# Scores are the same, don't send to API
logger.warning("Scores are the same, skipping...")
return
await self.add_rollouts_for_wandb(scored_data, item)
# Check for ref_logprobs
@ -561,16 +595,26 @@ class BaseEnv(ABC):
for mask in scored_data["masks"]:
self.completion_lengths.append(len(mask))
# Add the scores to the queue
if any([len(x) >= self.max_token_len for x in scored_data["tokens"]]):
if abort_on_any_max_length_exceeded and any(
[len(x) >= self.max_token_len for x in scored_data["tokens"]]
):
# Don't send to API if the token length is too long
logger.warning("Token length is too long, skipping...")
return
# Save data, if applicable:
if self.config.include_messages and scored_data.get("messages") is None:
scored_data["messages"] = [
self.tokenizer.decode(scored_data["tokens"][i])
for i in range(group_size)
]
if self.jsonl_writer is not None:
self.jsonl_writer.write(scored_data)
print(f"Wrote scored group to {self.config.data_path_to_save_groups}")
# Send data with retries and error handling
try:
self.items_sent_this_step += 1
await self._send_scored_data_to_api(scored_data)
if do_send_to_api:
self.items_sent_this_step += 1
await self._send_scored_data_to_api(scored_data)
except (Exception, TimeoutError) as e:
print(f"Failed to send scored data after retries: {e}")
@ -733,12 +777,17 @@ class BaseEnv(ABC):
)
)
async def env_manager(self):
async def env_manager(self, use_api=True):
"""
Rollout manager
"""
await self.setup()
await self.setup_wandb()
if not self.use_api:
await self.add_train_workers()
return
await self.register_env()
await self.get_server_info()
# Wait for other instances to get setup :)
@ -783,6 +832,69 @@ class BaseEnv(ABC):
await self.add_train_workers()
await asyncio.sleep(0.1)
async def process_manager(self):
    """
    Process manager for running a specific number of groups.

    Calls ``self.setup()``, optionally starts a wandb run, then repeatedly
    fetches an item with ``get_next_item`` and collects/post-processes its
    trajectories until ``self.n_groups_to_process`` groups have been
    processed or ``get_next_item`` returns ``None``. Each processed group is
    passed to ``handle_send_to_api`` with ``do_send_to_api=False`` and
    ``abort_on_any_max_length_exceeded=False``.

    The JSONL writer, if one is open, is closed even when processing raises.
    The HTML report is generated only after a normal finish and only when
    ``self.config.data_path_to_save_groups`` is set.

    Reads ``self.n_groups_to_process`` and ``self.group_size_to_process``,
    which must be assigned on the instance before this coroutine runs.
    """
    await self.setup()

    if self.config.use_wandb:
        random_id = "".join(random.choices(string.ascii_lowercase, k=6))
        current_date = datetime.now().strftime("%Y-%m-%d")
        wandb_run_name = f"{self.name}-{current_date}-{random_id}"
        wandb.init(project=self.wandb_project, name=wandb_run_name)

    # Initialize the processing
    self.curr_step = 0

    print(f"Starting to process {self.n_groups_to_process} groups...")

    try:
        # Process the required number of groups
        while self.curr_step < self.n_groups_to_process:
            # Get an item to process
            item = await self.get_next_item()
            if item is None:
                print("No more items to process")
                break

            # Process the group
            print(f"Processing group {self.curr_step + 1}/{self.n_groups_to_process}")

            # Override the configured group size with the process-mode value.
            # NOTE: the previous value is not restored afterwards.
            self.config.group_size = self.group_size_to_process

            # Collect and process the trajectories
            to_postprocess, _ = await self.collect_trajectories(item)

            if to_postprocess:
                # Post-process the trajectories
                processed_data = await self.postprocess_histories(to_postprocess)

                # Save to output file (don't send to API)
                await self.handle_send_to_api(
                    processed_data,
                    item,
                    do_send_to_api=False,
                    abort_on_any_max_length_exceeded=False,
                )
                await self.wandb_log()

                self.curr_step += 1
                print(
                    f"Successfully processed group {self.curr_step}/{self.n_groups_to_process}"
                )
            else:
                # NOTE(review): curr_step is not advanced here, so a source
                # that keeps yielding items whose collection comes back empty
                # will retry without bound.
                print("Failed to process group, retrying...")

        print(f"Completed processing {self.curr_step} groups")
    finally:
        # Always release the output file, including on errors, so the handle
        # is not leaked and buffered records are flushed.
        if self.jsonl_writer is not None:
            self.jsonl_writer.close()

    # Only render the HTML view when there is an output file to render from.
    if self.config.data_path_to_save_groups is not None:
        generate_html(self.config.data_path_to_save_groups)
@classmethod
def cli(cls):
"""
@ -804,10 +916,8 @@ class BaseEnv(ABC):
print()
print(ex.message.split("error: ")[-1])
return 2
else:
# For any other exception
print(f"Error: {str(ex)}", file=sys.stderr)
return 1
raise ex
run_and_exit(
subcommands,
@ -825,9 +935,18 @@ class BaseEnv(ABC):
"""
env_config, server_configs = cls.config_init()
env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"
openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}"
server_full_prefix = f"{SERVER_MANAGER_NAMESPACE}{NAMESPACE_SEP}"
class CliServeConfig(
cls.env_config_cls, OpenaiConfig, ServerManagerConfig, Cmd
get_prefixed_pydantic_model(type(env_config), env_full_prefix),
get_prefixed_pydantic_model(OpenaiConfig, openai_full_prefix),
get_prefixed_pydantic_model(
ServerManagerConfig,
server_full_prefix,
),
Cmd,
):
"""
Configuration for the serve command.
@ -837,8 +956,9 @@ class BaseEnv(ABC):
def run(self) -> None:
"""The logic to execute for the 'serve' command."""
# Convert this config into the formats needed by BaseEnv
if self.wandb_name is None and cls.name is not None:
self.wandb_name = cls.name
wandb_name_attr = f"{ENV_NAMESPACE}{NAMESPACE_SEP}wandb_name"
if getattr(self, wandb_name_attr) is None and cls.name is not None:
setattr(self, wandb_name_attr, cls.name)
model_dumped = self.model_dump(exclude_unset=True)
server_manager_config = ServerManagerConfig(**model_dumped)
# Create the environment instance
@ -863,29 +983,133 @@ class BaseEnv(ABC):
type: The CliProcessConfig class for processing commands.
"""
class CliProcessConfig(Cmd):
PROCESS_MODE_ENV_DEFAULT_CONFIG = BaseEnvConfig(
group_size=8,
total_steps=2,
ensure_scores_are_not_same=False,
include_messages=True,
)
PROCESS_MODE_OPENAI_DEFAULT_CONFIG = OpenaiConfig(
model_name="gpt-4.1-nano",
base_url=None,
api_key=os.environ.get("OPENAI_API_KEY"),
)
PROCESS_MODE_SERVER_MANAGER_DEFAULT_CONFIG = ServerManagerConfig(
slurm=False,
testing=False,
)
default_env_config, default_openai_config = cls.config_init()
if isinstance(default_openai_config, list):
default_openai_config = default_openai_config[0]
env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"
openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}"
server_full_prefix = f"{SERVER_MANAGER_NAMESPACE}{NAMESPACE_SEP}"
env_config_cls_new_defaults = adjust_model_defaults(
type(default_env_config), PROCESS_MODE_ENV_DEFAULT_CONFIG
)
openai_config_cls_new_defaults = adjust_model_defaults(
OpenaiConfig, PROCESS_MODE_OPENAI_DEFAULT_CONFIG
)
server_manager_config_cls_new_defaults = adjust_model_defaults(
ServerManagerConfig,
PROCESS_MODE_SERVER_MANAGER_DEFAULT_CONFIG,
)
class CliProcessConfig(
get_prefixed_pydantic_model(env_config_cls_new_defaults, env_full_prefix),
get_prefixed_pydantic_model(
openai_config_cls_new_defaults, openai_full_prefix
),
get_prefixed_pydantic_model(
server_manager_config_cls_new_defaults, server_full_prefix
),
Cmd,
):
"""
Configuration for the process command.
This is a placeholder for future implementation.
"""
# Add process-specific fields here
group_size: int = Field(
default=4, description="Number of responses per prompt"
)
n_groups: int = Field(default=1, description="Number of groups to process")
output_file: str = Field(
..., description="Path to jsonl file to write results"
config: str | None = Field(
default=None,
description="Path to .yaml config file. CLI args override this.",
)
def run(self) -> None:
"""The logic to execute for the 'process' command."""
# Setup environment configuration
wandb_name_attr = f"{ENV_NAMESPACE}{NAMESPACE_SEP}wandb_name"
if getattr(self, wandb_name_attr) is None and cls.name is not None:
setattr(self, wandb_name_attr, cls.name)
if self.config is not None:
with open(self.config, "r") as f:
config = yaml.safe_load(f)
print(f"Loaded config from {self.config}")
else:
config = {}
cli_passed_flags = get_double_dash_flags()
# cli args overrides config file which overrides class defaults which overrides process mode defaults
env_config = env_config_cls_new_defaults(
**merge_dicts(
default_env_config.model_dump(),
PROCESS_MODE_ENV_DEFAULT_CONFIG.model_dump(),
config.get(ENV_NAMESPACE, {}),
extract_namespace(
cli_passed_flags, env_full_prefix
), # only extract namespace for cli-passed args
)
)
openai_config = openai_config_cls_new_defaults(
**merge_dicts(
default_openai_config.model_dump(),
PROCESS_MODE_OPENAI_DEFAULT_CONFIG.model_dump(),
config.get(OPENAI_NAMESPACE, {}),
extract_namespace(
cli_passed_flags, openai_full_prefix
), # only extract namespace for cli-passed args
)
)
server_manager_config = server_manager_config_cls_new_defaults(
**merge_dicts(
ServerManagerConfig().model_dump(),
PROCESS_MODE_SERVER_MANAGER_DEFAULT_CONFIG.model_dump(),
config.get(SERVER_MANAGER_NAMESPACE, {}),
extract_namespace(
cli_passed_flags, server_full_prefix
), # only extract namespace for cli-passed args
)
)
# Create the environment instance
env = cls(
config=env_config,
server_configs=[openai_config],
slurm=server_manager_config.slurm,
testing=server_manager_config.testing,
)
# Set the process mode parameters
env.process_mode = True
env.n_groups_to_process = env_config.total_steps
env.group_size_to_process = env_config.group_size
print(
f"Processing {self.n_groups} groups of "
f"{self.group_size} responses and "
f"writing to {self.output_file}"
f"Processing {env_config.total_steps} groups of "
f"{env_config.group_size} responses and "
f"writing to {env_config.data_path_to_save_groups}"
)
print("This is a placeholder implementation for the process command.")
asyncio.run(env.process_manager())
# Actual implementation would go here
return CliProcessConfig