mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-22 16:48:57 +00:00
add process subcommand
This commit is contained in:
parent
0f966ec3fb
commit
78cfef9daf
6 changed files with 826 additions and 29 deletions
|
|
@ -1,26 +1,42 @@
|
|||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import string
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import string
|
||||
import time
|
||||
import uuid
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union
|
||||
|
||||
import aiohttp
|
||||
import jsonlines
|
||||
import numpy as np
|
||||
import wandb
|
||||
import yaml
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic_cli import Cmd, FailedExecutionException, run_and_exit
|
||||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
from trajectoryhandler.envs.constants import (
|
||||
ENV_NAMESPACE,
|
||||
NAMESPACE_SEP,
|
||||
OPENAI_NAMESPACE,
|
||||
SERVER_MANAGER_NAMESPACE,
|
||||
)
|
||||
from trajectoryhandler.frontend.jsonl2html import generate_html
|
||||
from trajectoryhandler.utils.cli import (
|
||||
adjust_model_defaults,
|
||||
extract_namespace,
|
||||
get_double_dash_flags,
|
||||
get_prefixed_pydantic_model,
|
||||
merge_dicts,
|
||||
)
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
import wandb
|
||||
from atroposlib.type_definitions import UUID
|
||||
from atroposlib.utils.metrics import get_std_min_max_avg
|
||||
|
||||
|
|
@ -130,12 +146,15 @@ class BaseEnvConfig(BaseModel):
|
|||
default=2,
|
||||
description="Minimum number of items sent before logging, if 0 or less, logs every time",
|
||||
)
|
||||
include_messages: bool = Field(
|
||||
default=False,
|
||||
description="Whether to include messages in the output transmitted to the trainer",
|
||||
)
|
||||
|
||||
|
||||
class BaseEnv(ABC):
|
||||
|
||||
name = None
|
||||
env_config_cls = BaseEnvConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -176,10 +195,22 @@ class BaseEnv(ABC):
|
|||
self.checkpoint_dir = ""
|
||||
self.checkpoint_interval = -1
|
||||
if self.config.data_path_to_save_groups is not None:
|
||||
if os.path.exists(self.config.data_path_to_save_groups):
|
||||
raise FileExistsError(
|
||||
"Data path already exists! Please remove it or change it."
|
||||
# Find a suitable filename by appending _1, _2, etc. if the file already exists
|
||||
original_path = self.config.data_path_to_save_groups
|
||||
counter = 1
|
||||
path_changed = False
|
||||
while os.path.exists(self.config.data_path_to_save_groups):
|
||||
path_obj = Path(original_path)
|
||||
self.config.data_path_to_save_groups = str(
|
||||
path_obj.with_stem(f"{path_obj.stem}_{counter}")
|
||||
)
|
||||
counter += 1
|
||||
path_changed = True
|
||||
if path_changed:
|
||||
print(
|
||||
f"Changed data path to {self.config.data_path_to_save_groups} because {original_path} already exists." # noqa: E501
|
||||
)
|
||||
|
||||
self.jsonl_writer = jsonlines.open(
|
||||
self.config.data_path_to_save_groups, "w"
|
||||
) # type: jsonlines.Writer
|
||||
|
|
@ -193,7 +224,7 @@ class BaseEnv(ABC):
|
|||
"""
|
||||
Initialize the config
|
||||
"""
|
||||
return cls.env_config_cls(), ServerBaseline()
|
||||
return BaseEnvConfig(), ServerBaseline()
|
||||
|
||||
async def collect_trajectory(self, item: Item) -> Tuple[Any | None, List[Item]]:
|
||||
raise NotImplementedError(
|
||||
|
|
@ -527,6 +558,8 @@ class BaseEnv(ABC):
|
|||
self,
|
||||
scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
|
||||
item: Item = None,
|
||||
do_send_to_api: bool = True,
|
||||
abort_on_any_max_length_exceeded: bool = True,
|
||||
):
|
||||
"""
|
||||
Send the chats to the API with robust error handling and support for multiple ScoredDataGroups.
|
||||
|
|
@ -546,6 +579,7 @@ class BaseEnv(ABC):
|
|||
if self.config.ensure_scores_are_not_same:
|
||||
if len(set(scored_data["scores"])) == 1:
|
||||
# Scores are the same, don't send to API
|
||||
logger.warning("Scores are the same, skipping...")
|
||||
return
|
||||
await self.add_rollouts_for_wandb(scored_data, item)
|
||||
# Check for ref_logprobs
|
||||
|
|
@ -561,16 +595,26 @@ class BaseEnv(ABC):
|
|||
for mask in scored_data["masks"]:
|
||||
self.completion_lengths.append(len(mask))
|
||||
# Add the scores to the queue
|
||||
if any([len(x) >= self.max_token_len for x in scored_data["tokens"]]):
|
||||
if abort_on_any_max_length_exceeded and any(
|
||||
[len(x) >= self.max_token_len for x in scored_data["tokens"]]
|
||||
):
|
||||
# Don't send to API if the token length is too long
|
||||
logger.warning("Token length is too long, skipping...")
|
||||
return
|
||||
# Save data, if applicable:
|
||||
if self.config.include_messages and scored_data.get("messages") is None:
|
||||
scored_data["messages"] = [
|
||||
self.tokenizer.decode(scored_data["tokens"][i])
|
||||
for i in range(group_size)
|
||||
]
|
||||
if self.jsonl_writer is not None:
|
||||
self.jsonl_writer.write(scored_data)
|
||||
print(f"Wrote scored group to {self.config.data_path_to_save_groups}")
|
||||
# Send data with retries and error handling
|
||||
try:
|
||||
self.items_sent_this_step += 1
|
||||
await self._send_scored_data_to_api(scored_data)
|
||||
if do_send_to_api:
|
||||
self.items_sent_this_step += 1
|
||||
await self._send_scored_data_to_api(scored_data)
|
||||
except (Exception, TimeoutError) as e:
|
||||
print(f"Failed to send scored data after retries: {e}")
|
||||
|
||||
|
|
@ -733,12 +777,17 @@ class BaseEnv(ABC):
|
|||
)
|
||||
)
|
||||
|
||||
async def env_manager(self):
|
||||
async def env_manager(self, use_api=True):
|
||||
"""
|
||||
Rollout manager
|
||||
"""
|
||||
await self.setup()
|
||||
await self.setup_wandb()
|
||||
|
||||
if not self.use_api:
|
||||
await self.add_train_workers()
|
||||
return
|
||||
|
||||
await self.register_env()
|
||||
await self.get_server_info()
|
||||
# Wait for other instances to get setup :)
|
||||
|
|
@ -783,6 +832,69 @@ class BaseEnv(ABC):
|
|||
await self.add_train_workers()
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def process_manager(self):
    """
    Process manager for running a specific number of groups.

    Runs ``self.setup()``, optionally starts a wandb run, then repeatedly
    pulls an item with ``get_next_item()``, collects and post-processes its
    trajectories, and hands the result to ``handle_send_to_api`` with
    ``do_send_to_api=False`` so the group is only written locally.  The loop
    ends once ``self.n_groups_to_process`` groups have succeeded or
    ``get_next_item()`` returns ``None``.

    Reads ``self.n_groups_to_process`` and ``self.group_size_to_process``
    (both must be set by the caller before this coroutine runs) and
    overwrites ``self.curr_step`` and ``self.config.group_size``.

    The JSONL writer, when present, is always closed on exit -- including
    when a step raises -- so the output file is never left open.  The HTML
    report is generated only after a successful run and only when an output
    path is configured.
    """
    await self.setup()

    if self.config.use_wandb:
        # Run name: <env name>-<date>-<6 random lowercase letters>.
        random_id = "".join(random.choices(string.ascii_lowercase, k=6))
        current_date = datetime.now().strftime("%Y-%m-%d")
        wandb_run_name = f"{self.name}-{current_date}-{random_id}"
        wandb.init(project=self.wandb_project, name=wandb_run_name)

    # Initialize the processing
    self.curr_step = 0

    print(f"Starting to process {self.n_groups_to_process} groups...")

    try:
        # Process the required number of groups
        while self.curr_step < self.n_groups_to_process:
            # Get an item to process
            item = await self.get_next_item()
            if item is None:
                print("No more items to process")
                break

            # Process the group
            print(f"Processing group {self.curr_step + 1}/{self.n_groups_to_process}")

            # Collect trajectories with the specified group size.
            # NOTE(review): the override is never restored afterwards;
            # confirm nothing relies on the previous group_size later.
            self.config.group_size = self.group_size_to_process

            # Collect and process the trajectories
            to_postprocess, _ = await self.collect_trajectories(item)

            if not to_postprocess:
                # curr_step is not advanced, so the loop fetches another
                # item and tries again.
                print("Failed to process group, retrying...")
                continue

            # Post-process the trajectories
            processed_data = await self.postprocess_histories(to_postprocess)

            # Save to output file (don't send to API)
            await self.handle_send_to_api(
                processed_data,
                item,
                do_send_to_api=False,
                abort_on_any_max_length_exceeded=False,
            )
            await self.wandb_log()

            self.curr_step += 1
            print(
                f"Successfully processed group {self.curr_step}/{self.n_groups_to_process}"
            )

        print(f"Completed processing {self.curr_step} groups")
    finally:
        # Close the output file if it's open, even if a step raised, so
        # buffered groups are flushed and the handle is not leaked.
        if self.jsonl_writer is not None:
            self.jsonl_writer.close()

    # Without a configured output path there is no JSONL file to render.
    if self.config.data_path_to_save_groups is not None:
        generate_html(self.config.data_path_to_save_groups)
|
||||
|
||||
@classmethod
|
||||
def cli(cls):
|
||||
"""
|
||||
|
|
@ -804,10 +916,8 @@ class BaseEnv(ABC):
|
|||
print()
|
||||
print(ex.message.split("error: ")[-1])
|
||||
return 2
|
||||
else:
|
||||
# For any other exception
|
||||
print(f"Error: {str(ex)}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
raise ex
|
||||
|
||||
run_and_exit(
|
||||
subcommands,
|
||||
|
|
@ -825,9 +935,18 @@ class BaseEnv(ABC):
|
|||
"""
|
||||
|
||||
env_config, server_configs = cls.config_init()
|
||||
env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"
|
||||
openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}"
|
||||
server_full_prefix = f"{SERVER_MANAGER_NAMESPACE}{NAMESPACE_SEP}"
|
||||
|
||||
class CliServeConfig(
|
||||
cls.env_config_cls, OpenaiConfig, ServerManagerConfig, Cmd
|
||||
get_prefixed_pydantic_model(type(env_config), env_full_prefix),
|
||||
get_prefixed_pydantic_model(OpenaiConfig, openai_full_prefix),
|
||||
get_prefixed_pydantic_model(
|
||||
ServerManagerConfig,
|
||||
server_full_prefix,
|
||||
),
|
||||
Cmd,
|
||||
):
|
||||
"""
|
||||
Configuration for the serve command.
|
||||
|
|
@ -837,8 +956,9 @@ class BaseEnv(ABC):
|
|||
def run(self) -> None:
|
||||
"""The logic to execute for the 'serve' command."""
|
||||
# Convert this config into the formats needed by BaseEnv
|
||||
if self.wandb_name is None and cls.name is not None:
|
||||
self.wandb_name = cls.name
|
||||
wandb_name_attr = f"{ENV_NAMESPACE}{NAMESPACE_SEP}wandb_name"
|
||||
if getattr(self, wandb_name_attr) is None and cls.name is not None:
|
||||
setattr(self, wandb_name_attr, cls.name)
|
||||
model_dumped = self.model_dump(exclude_unset=True)
|
||||
server_manager_config = ServerManagerConfig(**model_dumped)
|
||||
# Create the environment instance
|
||||
|
|
@ -863,29 +983,133 @@ class BaseEnv(ABC):
|
|||
type: The CliProcessConfig class for processing commands.
|
||||
"""
|
||||
|
||||
class CliProcessConfig(Cmd):
|
||||
PROCESS_MODE_ENV_DEFAULT_CONFIG = BaseEnvConfig(
|
||||
group_size=8,
|
||||
total_steps=2,
|
||||
ensure_scores_are_not_same=False,
|
||||
include_messages=True,
|
||||
)
|
||||
PROCESS_MODE_OPENAI_DEFAULT_CONFIG = OpenaiConfig(
|
||||
model_name="gpt-4.1-nano",
|
||||
base_url=None,
|
||||
api_key=os.environ.get("OPENAI_API_KEY"),
|
||||
)
|
||||
PROCESS_MODE_SERVER_MANAGER_DEFAULT_CONFIG = ServerManagerConfig(
|
||||
slurm=False,
|
||||
testing=False,
|
||||
)
|
||||
|
||||
default_env_config, default_openai_config = cls.config_init()
|
||||
|
||||
if isinstance(default_openai_config, list):
|
||||
default_openai_config = default_openai_config[0]
|
||||
|
||||
env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"
|
||||
openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}"
|
||||
server_full_prefix = f"{SERVER_MANAGER_NAMESPACE}{NAMESPACE_SEP}"
|
||||
|
||||
env_config_cls_new_defaults = adjust_model_defaults(
|
||||
type(default_env_config), PROCESS_MODE_ENV_DEFAULT_CONFIG
|
||||
)
|
||||
openai_config_cls_new_defaults = adjust_model_defaults(
|
||||
OpenaiConfig, PROCESS_MODE_OPENAI_DEFAULT_CONFIG
|
||||
)
|
||||
server_manager_config_cls_new_defaults = adjust_model_defaults(
|
||||
ServerManagerConfig,
|
||||
PROCESS_MODE_SERVER_MANAGER_DEFAULT_CONFIG,
|
||||
)
|
||||
|
||||
class CliProcessConfig(
|
||||
get_prefixed_pydantic_model(env_config_cls_new_defaults, env_full_prefix),
|
||||
get_prefixed_pydantic_model(
|
||||
openai_config_cls_new_defaults, openai_full_prefix
|
||||
),
|
||||
get_prefixed_pydantic_model(
|
||||
server_manager_config_cls_new_defaults, server_full_prefix
|
||||
),
|
||||
Cmd,
|
||||
):
|
||||
"""
|
||||
Configuration for the process command.
|
||||
This is a placeholder for future implementation.
|
||||
"""
|
||||
|
||||
# Add process-specific fields here
|
||||
group_size: int = Field(
|
||||
default=4, description="Number of responses per prompt"
|
||||
)
|
||||
n_groups: int = Field(default=1, description="Number of groups to process")
|
||||
output_file: str = Field(
|
||||
..., description="Path to jsonl file to write results"
|
||||
config: str | None = Field(
|
||||
default=None,
|
||||
description="Path to .yaml config file. CLI args override this.",
|
||||
)
|
||||
|
||||
def run(self) -> None:
|
||||
"""The logic to execute for the 'process' command."""
|
||||
# Setup environment configuration
|
||||
wandb_name_attr = f"{ENV_NAMESPACE}{NAMESPACE_SEP}wandb_name"
|
||||
if getattr(self, wandb_name_attr) is None and cls.name is not None:
|
||||
setattr(self, wandb_name_attr, cls.name)
|
||||
|
||||
if self.config is not None:
|
||||
with open(self.config, "r") as f:
|
||||
config = yaml.safe_load(f)
|
||||
print(f"Loaded config from {self.config}")
|
||||
else:
|
||||
config = {}
|
||||
|
||||
cli_passed_flags = get_double_dash_flags()
|
||||
|
||||
# cli args overrides config file which overrides class defaults which overrides process mode defaults
|
||||
env_config = env_config_cls_new_defaults(
|
||||
**merge_dicts(
|
||||
default_env_config.model_dump(),
|
||||
PROCESS_MODE_ENV_DEFAULT_CONFIG.model_dump(),
|
||||
config.get(ENV_NAMESPACE, {}),
|
||||
extract_namespace(
|
||||
cli_passed_flags, env_full_prefix
|
||||
), # only extract namespace for cli-passed args
|
||||
)
|
||||
)
|
||||
openai_config = openai_config_cls_new_defaults(
|
||||
**merge_dicts(
|
||||
default_openai_config.model_dump(),
|
||||
PROCESS_MODE_OPENAI_DEFAULT_CONFIG.model_dump(),
|
||||
config.get(OPENAI_NAMESPACE, {}),
|
||||
extract_namespace(
|
||||
cli_passed_flags, openai_full_prefix
|
||||
), # only extract namespace for cli-passed args
|
||||
)
|
||||
)
|
||||
server_manager_config = server_manager_config_cls_new_defaults(
|
||||
**merge_dicts(
|
||||
ServerManagerConfig().model_dump(),
|
||||
PROCESS_MODE_SERVER_MANAGER_DEFAULT_CONFIG.model_dump(),
|
||||
config.get(SERVER_MANAGER_NAMESPACE, {}),
|
||||
extract_namespace(
|
||||
cli_passed_flags, server_full_prefix
|
||||
), # only extract namespace for cli-passed args
|
||||
)
|
||||
)
|
||||
|
||||
# Create the environment instance
|
||||
env = cls(
|
||||
config=env_config,
|
||||
server_configs=[openai_config],
|
||||
slurm=server_manager_config.slurm,
|
||||
testing=server_manager_config.testing,
|
||||
)
|
||||
|
||||
# Set the process mode parameters
|
||||
env.process_mode = True
|
||||
env.n_groups_to_process = env_config.total_steps
|
||||
env.group_size_to_process = env_config.group_size
|
||||
|
||||
print(
|
||||
f"Processing {self.n_groups} groups of "
|
||||
f"{self.group_size} responses and "
|
||||
f"writing to {self.output_file}"
|
||||
f"Processing {env_config.total_steps} groups of "
|
||||
f"{env_config.group_size} responses and "
|
||||
f"writing to {env_config.data_path_to_save_groups}"
|
||||
)
|
||||
print("This is a placeholder implementation for the process command.")
|
||||
|
||||
asyncio.run(env.process_manager())
|
||||
|
||||
# Actual implementation would go here
|
||||
|
||||
return CliProcessConfig
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue