# Base Environment (BaseEnv)
The `BaseEnv` class (located in `trajectoryhandler/envs/base.py`) provides a foundation for creating custom reinforcement learning environments that interact with Atropos. When creating your own environment, you will typically subclass `BaseEnv` and implement several key methods.
## Core Methods to Implement
These methods must be implemented in your subclass:
- `async def setup(self)`: Called once at the beginning of the environment's lifecycle (by `env_manager`). Use it for any initial setup required for your specific environment, such as loading datasets, initializing models, or connecting to external resources.
- `async def get_next_item(self) -> Item`: Responsible for generating or retrieving the next piece of data (prompt, state, etc.) that will be used to start a new trajectory collection. If no more items are available or should be generated, it can return `None` to signal the worker to pause.
- `async def collect_trajectory(self, item: Item) -> Tuple[Any | None, List[Item]]`: Defines the logic for a single logical trajectory collection step based on the input `item`.
  - How it relates to multiple generations: `BaseEnv` uses `collect_trajectories` to run this method multiple times in parallel (controlled by `group_size`) to gather a batch of trajectories.
  - Your implementation: You can implement this method to generate one response/trajectory per call.
  - Return value: It returns a tuple containing:
    - The collected data for this step (one trajectory). This data can be processed further in `postprocess_histories` if you require additional filtering right before sending to the API.
    - A list of new `Item` objects to be added to the backlog for future processing (e.g., follow-up prompts).
- `async def evaluate(self, *args, **kwargs)`: Called periodically (controlled by `steps_per_eval` in the config) to perform evaluation runs. You define the evaluation logic here. The base class provides an example using `self.eval_workers` for parallel evaluation tasks, but you can implement any evaluation procedure suitable for your environment.
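To make the flow of the methods above concrete, here is a minimal, self-contained sketch of a subclass. The `BaseEnv` stub, the `PromptListEnv` name, and the item/trajectory shapes are illustrative assumptions for this sketch, not the real Atropos API — in practice you would subclass the actual `BaseEnv` from `trajectoryhandler/envs/base.py`:

```python
import asyncio
from typing import Any, List, Optional, Tuple


class BaseEnv:
    # Stand-in for trajectoryhandler.envs.base.BaseEnv so this sketch
    # runs on its own; the real base class provides far more machinery.
    async def setup(self) -> None: ...
    async def get_next_item(self) -> Optional[Any]: ...
    async def collect_trajectory(self, item: Any) -> Tuple[Optional[Any], List[Any]]: ...


class PromptListEnv(BaseEnv):
    """Toy environment that serves prompts from an in-memory list."""

    async def setup(self) -> None:
        # Load datasets / connect to external resources here.
        self.prompts = ["What is 2 + 2?", "Name a prime number."]
        self._idx = 0

    async def get_next_item(self) -> Optional[str]:
        # Return None when exhausted to signal the worker to pause.
        if self._idx >= len(self.prompts):
            return None
        item = self.prompts[self._idx]
        self._idx += 1
        return item

    async def collect_trajectory(self, item: str) -> Tuple[Optional[dict], List[str]]:
        # One trajectory per call; BaseEnv fans this out group_size times.
        trajectory = {"prompt": item, "completion": f"(model output for {item!r})"}
        backlog: List[str] = []  # follow-up Items could be queued here
        return trajectory, backlog


async def demo() -> List[dict]:
    # Drive the environment by hand; in Atropos the workers do this.
    env = PromptListEnv()
    await env.setup()
    results = []
    while (item := await env.get_next_item()) is not None:
        trajectory, _backlog = await env.collect_trajectory(item)
        results.append(trajectory)
    return results
```

Running `asyncio.run(demo())` walks each prompt through one `collect_trajectory` call, mirroring (at a much smaller scale) what the parallel workers do for you.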
## Optional Methods to Override
These methods have default implementations or are optional based on your needs:
- `async def collect_trajectories(self, item: Item) -> Tuple[Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]], List[Any | None]], List[Item]]`: The default implementation runs `collect_trajectory` multiple times in parallel (controlled by `group_size`). You can override this instead of `collect_trajectory` if you have a more efficient way to generate the entire group of responses/trajectories at once from the input `item`. It should return the collected group data and a list of backlog items.
- `async def postprocess_histories(self, trajectories: Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]) -> Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]]]`: Called after `collect_trajectories` and before the data is sent to the training server. It receives the collected data from the parallel runs (or from your custom `collect_trajectories` implementation). Use this to perform any final processing, scoring, or formatting you require before sending to the server. You usually won't need this.
- `async def wandb_log(self, wandb_metrics: Optional[Dict] = None)`: Called periodically to log metrics to Weights & Biases. If you override this to add custom metrics, ensure you call `super().wandb_log(wandb_metrics)` at the end of your implementation, so that the base class's performance metrics and rollout tables are also logged:

  ```python
  async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
      if wandb_metrics is None:
          wandb_metrics = {}
      # Add your custom metrics
      wandb_metrics['my_custom_metric'] = calculate_my_metric()
      # ... add more metrics
      # Call the parent method to log base metrics
      await super().wandb_log(wandb_metrics)
  ```
- `save_checkpoint(self, step, data=None)`: The base class calls this method automatically at checkpoint intervals determined by the server. It saves the provided `data` dictionary (which you might populate with environment-specific state) to a JSON file. You can override this to customize what data is saved or how it's saved (e.g., using a different format or location), but the triggering mechanism remains automatic.
- `@classmethod config_init(cls) -> Tuple[BaseEnvConfig, Union[ServerBaseline, List[OpenaiConfig]]]`: Used by the default `get_cli_serve_config_cls` implementation to get the initial environment configuration (a `BaseEnvConfig` subclass) and server configurations (`ServerBaseline` or `List[OpenaiConfig]`) when setting up the `serve` command. The default implementation returns `cls.env_config_cls(), ServerBaseline()`. You might override this if your environment requires different default configurations or specific server setups (like multiple `OpenaiConfig` instances) when run via the CLI `serve` command.
- `async def cleanup(self)`: Called after each call to `handle_env`. You can implement this for any cleanup needed after processing a single item, though it's often not required.
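As a sketch of the `config_init` override described above, the snippet below points the `serve` command at two inference servers. The `BaseEnvConfig` and `OpenaiConfig` classes here are minimal stand-ins so the example runs on its own, and their field names (`group_size`, `base_url`, `model`) are illustrative assumptions rather than the exact Atropos definitions:

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class BaseEnvConfig:
    # Stand-in for Atropos's BaseEnvConfig; field names are assumptions.
    group_size: int = 4
    steps_per_eval: int = 100


@dataclass
class OpenaiConfig:
    # Stand-in for Atropos's OpenaiConfig; field names are assumptions.
    base_url: str = "http://localhost:8000/v1"
    model: str = "my-model"


class MyEnvConfig(BaseEnvConfig):
    """Environment-specific configuration defaults would go here."""


class MyEnv:
    env_config_cls = MyEnvConfig

    @classmethod
    def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
        # Instead of the default (cls.env_config_cls(), ServerBaseline()),
        # return a tweaked config and an explicit list of server configs.
        config = cls.env_config_cls(group_size=8)
        servers = [
            OpenaiConfig(base_url="http://localhost:8000/v1"),
            OpenaiConfig(base_url="http://localhost:8001/v1"),
        ]
        return config, servers
```

Calling `MyEnv.config_init()` then yields the environment config and server list that the CLI `serve` path would consume.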
## Provided Functionality
`BaseEnv` provides several helpful features:
- Parallel Trajectory Collection (`collect_trajectories`): The base implementation runs your `collect_trajectory` method multiple times in parallel (based on `group_size`) and gathers the results. You can override `collect_trajectories` directly for custom group generation logic (see Optional Methods).
- Server Interaction: Handles registration with the rollout server, fetching configuration (like `batch_size`), sending scored data (`handle_send_to_api`, with retries), and status updates.
- WandB Integration: Sets up WandB logging (if enabled) based on server information and provides the `wandb_log` hook for custom metrics (remember to call `super().wandb_log()`). It uses the helper methods `add_rollouts_for_wandb` (to temporarily store rollout data) and `create_rollout_table` (to format the data into a `wandb.Table`). You can override either of these helpers for custom logging behavior (e.g., changing what data is stored or how the final table is structured).
- Checkpointing:
  - The environment automatically triggers checkpoint saves based on the `checkpoint_interval` received from the server, calling the `save_checkpoint` method (see Optional Methods).
  - `load_checkpoint(self)`: Loads data from the checkpoint file corresponding to the environment's `curr_step`, attempting to restore attributes of the environment object based on the keys in the loaded JSON data. This is called automatically if `curr_step > 0` during registration.
- Worker Management: Manages asynchronous worker tasks for collecting trajectories (`add_train_workers`, `handle_env`).
- Performance Monitoring: Tracks and logs various performance statistics (task durations, worker counts, etc.).
- CLI Integration: Provides a `cli()` class method using `pydantic-cli` to easily create command-line interfaces for your environment (e.g., `python your_env_module.py serve --port 8001 ...`). See `get_cli_serve_config_cls` and `get_cli_process_config_cls`.
By implementing the required methods and optionally overriding others, you can create diverse environments that leverage the distributed training infrastructure provided by the Atropos framework.