mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-22 16:48:57 +00:00
Merge commit '71e7a5ca27' into add-support-for-custom-api-servers
This commit is contained in:
commit
96be544228
45 changed files with 1605 additions and 494 deletions
|
|
@ -20,15 +20,12 @@ import wandb
|
|||
import yaml
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic_cli import Cmd, FailedExecutionException, run_and_exit
|
||||
from rich import print as rprint
|
||||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from atroposlib.envs.constants import (
|
||||
ENV_NAMESPACE,
|
||||
NAMESPACE_SEP,
|
||||
OPENAI_NAMESPACE,
|
||||
SERVER_MANAGER_NAMESPACE,
|
||||
)
|
||||
from atroposlib.envs.constants import ENV_NAMESPACE, NAMESPACE_SEP, OPENAI_NAMESPACE
|
||||
from atroposlib.envs.server_handling.openai_server import resolve_openai_configs
|
||||
from atroposlib.frontend.jsonl2html import generate_html
|
||||
from atroposlib.type_definitions import UUID
|
||||
from atroposlib.utils.cli import (
|
||||
|
|
@ -38,6 +35,7 @@ from atroposlib.utils.cli import (
|
|||
get_prefixed_pydantic_model,
|
||||
merge_dicts,
|
||||
)
|
||||
from atroposlib.utils.io import parse_http_response
|
||||
from atroposlib.utils.metrics import get_std_min_max_avg
|
||||
|
||||
from ..type_definitions import Item, Message
|
||||
|
|
@ -63,6 +61,17 @@ class ScoredDataGroup(TypedDict):
|
|||
overrides: Optional[List[Dict]]
|
||||
|
||||
|
||||
class ScoredDataItem(TypedDict):
|
||||
tokens: List[int]
|
||||
masks: List[int]
|
||||
scores: float
|
||||
advantages: Optional[List[float]]
|
||||
ref_logprobs: Optional[List[float]]
|
||||
messages: Optional[List[Message]]
|
||||
group_overrides: Optional[Dict]
|
||||
overrides: Optional[Dict]
|
||||
|
||||
|
||||
class EvalHandlingEnum(Enum):
|
||||
"""
|
||||
Enum for handling evals.
|
||||
|
|
@ -237,7 +246,9 @@ class BaseEnv(ABC):
|
|||
"""
|
||||
return cls.env_config_cls(), ServerBaseline(), None
|
||||
|
||||
async def collect_trajectory(self, item: Item) -> Tuple[Any | None, List[Item]]:
|
||||
async def collect_trajectory(
|
||||
self, item: Item
|
||||
) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]:
|
||||
raise NotImplementedError(
|
||||
"Handle env single method must be implemented in subclass "
|
||||
)
|
||||
|
|
@ -257,13 +268,38 @@ class BaseEnv(ABC):
|
|||
for _ in range(self.config.group_size):
|
||||
tasks.append(self.collect_trajectory(item))
|
||||
results = await asyncio.gather(*tasks)
|
||||
if any(not isinstance(result[0], dict) for result in results):
|
||||
logging.error("something wasn't a ScoredDataItem")
|
||||
raise ValueError(
|
||||
"collect_trajectory must return a ScoredDataItem or None to use the default "
|
||||
"collect_trajectories method"
|
||||
)
|
||||
backlog = []
|
||||
to_postprocess = []
|
||||
to_postprocess = ScoredDataGroup()
|
||||
to_postprocess["tokens"] = []
|
||||
to_postprocess["masks"] = []
|
||||
to_postprocess["scores"] = []
|
||||
to_postprocess["advantages"] = []
|
||||
to_postprocess["ref_logprobs"] = []
|
||||
to_postprocess["messages"] = []
|
||||
to_postprocess["group_overrides"] = {}
|
||||
to_postprocess["overrides"] = []
|
||||
print("Processing results")
|
||||
for result in results:
|
||||
if result[0] is not None:
|
||||
to_postprocess.append(result[0])
|
||||
to_postprocess["tokens"].append(result[0]["tokens"])
|
||||
to_postprocess["masks"].append(result[0]["masks"])
|
||||
to_postprocess["scores"].append(result[0]["scores"])
|
||||
if result[0].get("advantages", None) is not None:
|
||||
to_postprocess["advantages"].append(result[0]["advantages"])
|
||||
if result[0].get("ref_logprobs", None) is not None:
|
||||
to_postprocess["ref_logprobs"].append(result[0]["ref_logprobs"])
|
||||
if result[0].get("messages", None) is not None:
|
||||
to_postprocess["messages"].append(result[0]["messages"])
|
||||
if result[0].get("group_overrides", None) is not None:
|
||||
to_postprocess["group_overrides"].update(result[0]["group_overrides"])
|
||||
if result[0].get("overrides", None) is not None:
|
||||
to_postprocess["overrides"].append(result[0]["overrides"])
|
||||
backlog.extend(result[1])
|
||||
random.shuffle(backlog)
|
||||
return to_postprocess, backlog
|
||||
|
||||
async def postprocess_histories(
|
||||
|
|
@ -358,7 +394,7 @@ class BaseEnv(ABC):
|
|||
async with session.get(
|
||||
f"{self.config.rollout_server_url}/wandb_info"
|
||||
) as resp:
|
||||
data = await resp.json()
|
||||
data = await parse_http_response(resp, logger)
|
||||
self.wandb_group = data["group"]
|
||||
self.wandb_project = data["project"]
|
||||
if self.wandb_project is None:
|
||||
|
|
@ -386,7 +422,7 @@ class BaseEnv(ABC):
|
|||
"weight": self.config.inference_weight,
|
||||
},
|
||||
) as resp:
|
||||
data = await resp.json()
|
||||
data = await parse_http_response(resp, logger)
|
||||
return data
|
||||
except Exception as e:
|
||||
logger.error(f"Error registering env: {e}")
|
||||
|
|
@ -427,7 +463,7 @@ class BaseEnv(ABC):
|
|||
"""
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(f"{self.config.rollout_server_url}/info") as resp:
|
||||
data = await resp.json()
|
||||
data = await parse_http_response(resp, logger)
|
||||
if data["batch_size"] != -1:
|
||||
# update the batch size
|
||||
self.config.batch_size = data["batch_size"]
|
||||
|
|
@ -710,7 +746,7 @@ class BaseEnv(ABC):
|
|||
f"{self.config.rollout_server_url}/status-env",
|
||||
json={"env_id": self.env_id},
|
||||
) as resp:
|
||||
self.status_dict = await resp.json()
|
||||
self.status_dict = await parse_http_response(resp, logger)
|
||||
new_weight = self.status_dict["env_weight"]
|
||||
max_num_workers = self.config.max_num_workers
|
||||
if max_num_workers == -1:
|
||||
|
|
@ -964,53 +1000,158 @@ class BaseEnv(ABC):
|
|||
Returns:
|
||||
type: The CliServeConfig class for serving commands.
|
||||
"""
|
||||
# Get the default configurations defined by the specific environment class
|
||||
configs_and_maybe_server_class = cls.config_init()
|
||||
if len(configs_and_maybe_server_class) == 2:
|
||||
env_config, server_configs = configs_and_maybe_server_class
|
||||
default_env_config, default_server_configs = configs_and_maybe_server_class
|
||||
server_class = None
|
||||
else:
|
||||
env_config, server_configs, server_class = configs_and_maybe_server_class
|
||||
default_env_config, default_server_configs, server_class = (
|
||||
configs_and_maybe_server_class
|
||||
)
|
||||
|
||||
# Define namespace prefixes for CLI arguments and YAML keys
|
||||
env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"
|
||||
openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}"
|
||||
|
||||
# Define the CLI configuration class dynamically
|
||||
class CliServeConfig(
|
||||
get_prefixed_pydantic_model(type(env_config), env_full_prefix),
|
||||
get_prefixed_pydantic_model(APIServerConfig, openai_full_prefix),
|
||||
ServerManagerConfig,
|
||||
get_prefixed_pydantic_model(type(default_env_config), env_full_prefix),
|
||||
get_prefixed_pydantic_model(
|
||||
APIServerConfig, openai_full_prefix
|
||||
), # Use APIServerConfig for CLI args
|
||||
ServerManagerConfig, # ServerManager args are not namespaced by default
|
||||
Cmd,
|
||||
):
|
||||
"""
|
||||
Configuration for the serve command.
|
||||
This combines BaseEnvConfig and APIServerConfig into a single command.
|
||||
Supports overrides via YAML config file and CLI arguments.
|
||||
Order of precedence: CLI > YAML > Class Defaults.
|
||||
"""
|
||||
|
||||
config: str | None = Field(
|
||||
default=None,
|
||||
description="Path to .yaml config file. CLI args override this.",
|
||||
)
|
||||
|
||||
def run(self) -> None:
|
||||
"""The logic to execute for the 'serve' command."""
|
||||
# Convert this config into the formats needed by BaseEnv
|
||||
# Set default wandb name if not provided and class has a name
|
||||
# Note: This modifies the 'self' instance based on CLI args before full parsing.
|
||||
wandb_name_attr = f"{ENV_NAMESPACE}{NAMESPACE_SEP}wandb_name"
|
||||
if getattr(self, wandb_name_attr) is None and cls.name is not None:
|
||||
if (
|
||||
getattr(self, wandb_name_attr, None) is None
|
||||
and cls.name is not None
|
||||
):
|
||||
setattr(self, wandb_name_attr, cls.name)
|
||||
model_dumped = self.model_dump(exclude_unset=True)
|
||||
server_manager_config = ServerManagerConfig(**model_dumped)
|
||||
# Create the environment instance
|
||||
try:
|
||||
env = cls(
|
||||
config=env_config,
|
||||
server_configs=server_configs,
|
||||
slurm=server_manager_config.slurm,
|
||||
testing=server_manager_config.testing,
|
||||
server_class=server_class,
|
||||
|
||||
# Load configuration from YAML file if specified
|
||||
if self.config is not None:
|
||||
with open(self.config, "r") as f:
|
||||
yaml_config = yaml.safe_load(f)
|
||||
print(f"Loaded config from {self.config}")
|
||||
else:
|
||||
yaml_config = {}
|
||||
|
||||
# Get CLI flags passed with double dashes (e.g., --env--foo bar)
|
||||
cli_passed_flags = get_double_dash_flags()
|
||||
|
||||
# --- Configuration Merging ---
|
||||
# Priority: CLI > YAML > Class Defaults
|
||||
|
||||
# 1. Environment Configuration
|
||||
env_config_dict = merge_dicts(
|
||||
default_env_config.model_dump(), # Class Defaults
|
||||
yaml_config.get(ENV_NAMESPACE, {}), # YAML config
|
||||
extract_namespace(cli_passed_flags, env_full_prefix), # CLI args
|
||||
)
|
||||
|
||||
# 2. OpenAI Configuration (used for potential overrides)
|
||||
oai_cli_passed_args = extract_namespace(
|
||||
cli_passed_flags, openai_full_prefix
|
||||
) # CLI args
|
||||
yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {})
|
||||
if isinstance(default_server_configs, ServerBaseline) and (
|
||||
oai_cli_passed_args or yaml_oai_config
|
||||
):
|
||||
raise ValueError(
|
||||
"ServerBaseline is not compatible with OpenAI-namespaced CLI arguments. Please edit `config_init` directly or use APIServerConfig." # noqa: E501
|
||||
)
|
||||
except TypeError as e:
|
||||
warnings.warn(
|
||||
"Not supporting server_class will be deprecated soon, please add that kwarg"
|
||||
)
|
||||
env = cls(
|
||||
config=env_config,
|
||||
server_configs=server_configs,
|
||||
slurm=server_manager_config.slurm,
|
||||
testing=server_manager_config.testing,
|
||||
if (
|
||||
isinstance(default_server_configs, list)
|
||||
and len(default_server_configs) == 1
|
||||
):
|
||||
# can't use the same var name because it shadows the class variable and we get an error
|
||||
default_openai_config_ = default_server_configs[0]
|
||||
else:
|
||||
default_openai_config_ = default_server_configs
|
||||
if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1:
|
||||
yaml_oai_config = yaml_oai_config[0]
|
||||
if isinstance(default_openai_config_, APIServerConfig) and isinstance(
|
||||
yaml_oai_config, dict
|
||||
):
|
||||
openai_config_dict = merge_dicts(
|
||||
default_openai_config_.model_dump(), # Default APIServerConfig (or from class init)
|
||||
yaml_oai_config,
|
||||
oai_cli_passed_args,
|
||||
)
|
||||
else:
|
||||
openai_config_dict = {}
|
||||
|
||||
# 3. Server Manager Configuration (slurm, testing - not namespaced)
|
||||
# Extract only relevant CLI flags for ServerManager
|
||||
server_manager_cli_passed_flags = {}
|
||||
if "slurm" in cli_passed_flags:
|
||||
server_manager_cli_passed_flags["slurm"] = cli_passed_flags["slurm"]
|
||||
if "testing" in cli_passed_flags:
|
||||
server_manager_cli_passed_flags["testing"] = cli_passed_flags[
|
||||
"testing"
|
||||
]
|
||||
|
||||
server_manager_yaml_dict = {}
|
||||
if "slurm" in yaml_config:
|
||||
server_manager_yaml_dict["slurm"] = yaml_config["slurm"]
|
||||
if "testing" in yaml_config:
|
||||
server_manager_yaml_dict["testing"] = yaml_config["testing"]
|
||||
|
||||
server_manager_config_dict = merge_dicts(
|
||||
ServerManagerConfig().model_dump(), # Base defaults for ServerManager
|
||||
server_manager_yaml_dict, # YAML config
|
||||
server_manager_cli_passed_flags, # CLI args
|
||||
)
|
||||
|
||||
# --- Instantiate Final Config Objects ---
|
||||
# Create instances from the merged dictionaries using the original default types where appropriate
|
||||
|
||||
# Instantiate the final environment config using its original type
|
||||
env_config = type(default_env_config)(**env_config_dict)
|
||||
|
||||
# Instantiate the final server manager config
|
||||
server_manager_config = ServerManagerConfig(
|
||||
**server_manager_config_dict
|
||||
)
|
||||
|
||||
# Determine the final server_configs, handling single, multiple servers, and overrides.
|
||||
|
||||
openai_configs = resolve_openai_configs(
|
||||
default_server_configs=default_server_configs,
|
||||
openai_config_dict=openai_config_dict,
|
||||
yaml_config=yaml_config,
|
||||
cli_passed_flags=cli_passed_flags,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
# --- Create and Run Environment ---
|
||||
# Create the environment instance using the final, instantiated config objects
|
||||
env = cls(
|
||||
config=env_config,
|
||||
server_configs=openai_configs,
|
||||
slurm=server_manager_config.slurm,
|
||||
testing=server_manager_config.testing,
|
||||
)
|
||||
rprint(env_config)
|
||||
rprint(openai_configs)
|
||||
|
||||
# Run the environment
|
||||
asyncio.run(env.env_manager())
|
||||
|
||||
|
|
@ -1025,11 +1166,14 @@ class BaseEnv(ABC):
|
|||
type: The CliProcessConfig class for processing commands.
|
||||
"""
|
||||
|
||||
# Define specific default configurations for the 'process' mode
|
||||
PROCESS_MODE_ENV_DEFAULT_CONFIG = BaseEnvConfig(
|
||||
group_size=8,
|
||||
total_steps=2,
|
||||
ensure_scores_are_not_same=False,
|
||||
include_messages=True,
|
||||
data_path_to_save_groups=f"data/{cls.name or 'groups'}.jsonl",
|
||||
use_wandb=True,
|
||||
)
|
||||
PROCESS_MODE_OPENAI_DEFAULT_CONFIG = APIServerConfig(
|
||||
model_name="gpt-4.1-nano",
|
||||
|
|
@ -1041,21 +1185,22 @@ class BaseEnv(ABC):
|
|||
testing=False,
|
||||
)
|
||||
|
||||
# Get the base default configurations from the specific environment class
|
||||
configs_and_maybe_server_class = cls.config_init()
|
||||
if len(configs_and_maybe_server_class) == 2:
|
||||
default_env_config, default_openai_config = configs_and_maybe_server_class
|
||||
default_env_config, default_server_configs = configs_and_maybe_server_class
|
||||
server_class = None
|
||||
else:
|
||||
default_env_config, default_openai_config, server_class = (
|
||||
default_env_config, default_server_configs, server_class = (
|
||||
configs_and_maybe_server_class
|
||||
)
|
||||
|
||||
if isinstance(default_openai_config, list):
|
||||
default_openai_config = default_openai_config[0]
|
||||
|
||||
# Define namespace prefixes
|
||||
env_full_prefix = f"{ENV_NAMESPACE}{NAMESPACE_SEP}"
|
||||
openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}"
|
||||
|
||||
# Create Pydantic model classes with the 'process' mode defaults applied.
|
||||
# These adjusted classes will be used for final instantiation.
|
||||
env_config_cls_new_defaults = adjust_model_defaults(
|
||||
type(default_env_config), PROCESS_MODE_ENV_DEFAULT_CONFIG
|
||||
)
|
||||
|
|
@ -1077,6 +1222,8 @@ class BaseEnv(ABC):
|
|||
):
|
||||
"""
|
||||
Configuration for the process command.
|
||||
Supports overrides via YAML config file and CLI arguments.
|
||||
Order of precedence: CLI > YAML > Process Mode Defaults > `config_init` defaults.
|
||||
"""
|
||||
|
||||
config: str | None = Field(
|
||||
|
|
@ -1086,42 +1233,72 @@ class BaseEnv(ABC):
|
|||
|
||||
def run(self) -> None:
|
||||
"""The logic to execute for the 'process' command."""
|
||||
# Setup environment configuration
|
||||
# Set default wandb name if not provided and class has a name
|
||||
wandb_name_attr = f"{ENV_NAMESPACE}{NAMESPACE_SEP}wandb_name"
|
||||
if getattr(self, wandb_name_attr) is None and cls.name is not None:
|
||||
if (
|
||||
getattr(self, wandb_name_attr, None) is None
|
||||
and cls.name is not None
|
||||
):
|
||||
setattr(self, wandb_name_attr, cls.name)
|
||||
|
||||
# Load configuration from YAML file if specified
|
||||
if self.config is not None:
|
||||
with open(self.config, "r") as f:
|
||||
config = yaml.safe_load(f)
|
||||
yaml_config = yaml.safe_load(f)
|
||||
print(f"Loaded config from {self.config}")
|
||||
else:
|
||||
config = {}
|
||||
yaml_config = {}
|
||||
|
||||
# Get CLI flags passed with double dashes
|
||||
cli_passed_flags = get_double_dash_flags()
|
||||
|
||||
# cli args overrides config file which overrides class defaults which overrides process mode defaults
|
||||
env_config = env_config_cls_new_defaults(
|
||||
**merge_dicts(
|
||||
default_env_config.model_dump(),
|
||||
PROCESS_MODE_ENV_DEFAULT_CONFIG.model_dump(),
|
||||
config.get(ENV_NAMESPACE, {}),
|
||||
extract_namespace(
|
||||
cli_passed_flags, env_full_prefix
|
||||
), # only extract namespace for cli-passed args
|
||||
)
|
||||
)
|
||||
openai_config = openai_config_cls_new_defaults(
|
||||
**merge_dicts(
|
||||
default_openai_config.model_dump(),
|
||||
PROCESS_MODE_OPENAI_DEFAULT_CONFIG.model_dump(),
|
||||
config.get(OPENAI_NAMESPACE, {}),
|
||||
extract_namespace(
|
||||
cli_passed_flags, openai_full_prefix
|
||||
), # only extract namespace for cli-passed args
|
||||
)
|
||||
# --- Configuration Merging ---
|
||||
# Priority: CLI > YAML > Process Mode Defaults > `config_init` defaults
|
||||
|
||||
# 1. Environment Configuration
|
||||
env_config_dict = merge_dicts(
|
||||
default_env_config.model_dump(), # Class Defaults
|
||||
PROCESS_MODE_ENV_DEFAULT_CONFIG.model_dump(), # Process Mode Defaults
|
||||
yaml_config.get(ENV_NAMESPACE, {}), # YAML config
|
||||
extract_namespace(cli_passed_flags, env_full_prefix), # CLI args
|
||||
)
|
||||
|
||||
# 2. OpenAI Configuration
|
||||
oai_cli_passed_args = extract_namespace(
|
||||
cli_passed_flags, openai_full_prefix
|
||||
) # CLI args
|
||||
yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {})
|
||||
if isinstance(default_server_configs, ServerBaseline) and (
|
||||
oai_cli_passed_args or yaml_oai_config
|
||||
):
|
||||
raise ValueError(
|
||||
"ServerBaseline is not compatible with OpenAI-namespaced CLI arguments. Please edit `config_init` directly or use APIServerConfig." # noqa: E501
|
||||
)
|
||||
|
||||
if (
|
||||
isinstance(default_server_configs, list)
|
||||
and len(default_server_configs) == 1
|
||||
):
|
||||
# can't use the same var name because it shadows the class variable and we get an error
|
||||
default_openai_config_ = default_server_configs[0]
|
||||
else:
|
||||
default_openai_config_ = default_server_configs
|
||||
if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1:
|
||||
yaml_oai_config = yaml_oai_config[0]
|
||||
if isinstance(default_openai_config_, APIServerConfig) and isinstance(
|
||||
yaml_oai_config, dict
|
||||
):
|
||||
openai_config_dict = merge_dicts(
|
||||
default_openai_config_.model_dump(), # Default APIServerConfig (or from class init)
|
||||
PROCESS_MODE_OPENAI_DEFAULT_CONFIG.model_dump(), # Process Mode Defaults
|
||||
yaml_oai_config,
|
||||
oai_cli_passed_args,
|
||||
)
|
||||
else:
|
||||
openai_config_dict = {}
|
||||
|
||||
# 3. Server Manager Configuration
|
||||
# Extract only relevant CLI flags
|
||||
server_manager_cli_passed_flags = {}
|
||||
if "slurm" in cli_passed_flags:
|
||||
server_manager_cli_passed_flags["slurm"] = cli_passed_flags["slurm"]
|
||||
|
|
@ -1130,37 +1307,69 @@ class BaseEnv(ABC):
|
|||
"testing"
|
||||
]
|
||||
|
||||
server_manager_config = server_manager_config_cls_new_defaults(
|
||||
**merge_dicts(
|
||||
ServerManagerConfig().model_dump(),
|
||||
PROCESS_MODE_SERVER_MANAGER_DEFAULT_CONFIG.model_dump(),
|
||||
config.get(SERVER_MANAGER_NAMESPACE, {}),
|
||||
server_manager_cli_passed_flags,
|
||||
)
|
||||
server_manager_yaml_dict = {}
|
||||
if "slurm" in yaml_config:
|
||||
server_manager_yaml_dict["slurm"] = yaml_config["slurm"]
|
||||
if "testing" in yaml_config:
|
||||
server_manager_yaml_dict["testing"] = yaml_config["testing"]
|
||||
|
||||
server_manager_config_dict = merge_dicts(
|
||||
ServerManagerConfig().model_dump(), # Base defaults
|
||||
PROCESS_MODE_SERVER_MANAGER_DEFAULT_CONFIG.model_dump(), # Process Mode Defaults
|
||||
server_manager_yaml_dict,
|
||||
server_manager_cli_passed_flags, # CLI args
|
||||
)
|
||||
|
||||
# --- Instantiate Final Config Objects ---
|
||||
# Use the classes with adjusted defaults for instantiation
|
||||
|
||||
env_config = env_config_cls_new_defaults(**env_config_dict)
|
||||
server_manager_config = server_manager_config_cls_new_defaults(
|
||||
**server_manager_config_dict
|
||||
)
|
||||
|
||||
# Determine the final server_configs, handling single, multiple servers, and overrides.
|
||||
|
||||
openai_configs = resolve_openai_configs(
|
||||
default_server_configs=default_server_configs,
|
||||
openai_config_dict=openai_config_dict,
|
||||
yaml_config=yaml_config,
|
||||
cli_passed_flags=cli_passed_flags,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
rprint(env_config)
|
||||
rprint(openai_configs)
|
||||
|
||||
# --- Create and Run Environment ---
|
||||
# Create the environment instance
|
||||
env = cls(
|
||||
config=env_config,
|
||||
server_configs=[openai_config],
|
||||
server_configs=openai_configs,
|
||||
slurm=server_manager_config.slurm,
|
||||
testing=server_manager_config.testing,
|
||||
server_class=server_class,
|
||||
)
|
||||
|
||||
# Set the process mode parameters
|
||||
# Set specific parameters for process mode on the environment instance
|
||||
env.process_mode = True
|
||||
env.n_groups_to_process = env_config.total_steps
|
||||
env.group_size_to_process = env_config.group_size
|
||||
|
||||
# Validate that an output path is set (should have a default from PROCESS_MODE_ENV_DEFAULT_CONFIG)
|
||||
if env_config.data_path_to_save_groups is None:
|
||||
# This check might be redundant if the default is always set, but good practice.
|
||||
raise ValueError(
|
||||
"data_path_to_save_groups must be set for process mode"
|
||||
)
|
||||
|
||||
print(
|
||||
f"Processing {env_config.total_steps} groups of "
|
||||
f"{env_config.group_size} responses and "
|
||||
f"writing to {env_config.data_path_to_save_groups}"
|
||||
)
|
||||
|
||||
# Run the environment's asynchronous process manager function
|
||||
asyncio.run(env.process_manager())
|
||||
|
||||
# Actual implementation would go here
|
||||
|
||||
return CliProcessConfig
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue