diff --git a/README.md b/README.md
index 0ae8b8ae..2759cdc7 100644
--- a/README.md
+++ b/README.md
@@ -22,16 +22,17 @@
-Atropos is a Language Model Reinforcement Learning Environments framework for collecting and evaluating LLM trajectories through diverse environments including:
+---
+## What is Atropos?
+Atropos is an environment microservice framework for async RL with LLMs.
+
+Atropos comprises two parts: environments, which run as services, and a trajectory API that the environments send data to and that the trainer pulls batches from.
+
+![image](https://github.com/user-attachments/assets/8ce52994-b219-49d6-970c-58a477f36151)
-| Environment Type | Examples | Purpose |
-|---------------------------|--------------------------------------------|----------------------------------------------------|
-| 📚 Dataset environments | GSM8K, MMLU | Evaluate and improve LLM performance on static data|
-| 🎮 Online environments | Crosswords, Hangman | Train LLMs through interactive game-based learning |
-| 🤖 RLAIF and RLHF | LLM Judge/Reward Models | Fine-tune LLMs using human feedback and alignment |
-| 🔄 Multi-Turn RL | deepresearch, internal tool calling | Train LLMs on complex multi-step interactions |
+
+*Here is a diagram of how Atropos' components can interact with a trainer and an inference server to complete the RL loop (neither the trainer nor the inference engine is included in the atropos package)*
@@ -45,6 +46,19 @@ Atropos is a robust, scalable framework for **Reinforcement Learning Environment
 The goal: provide a flexible, scalable, and standardized platform to accelerate LLM-based RL research across diverse, interactive settings.
 
+The framework supports collecting, distributing, and evaluating LLM trajectories through diverse environments, including:
+
+<div align="center">
+
+| Environment Type | Examples | Purpose |
+|---------------------------|--------------------------------------------|----------------------------------------------------|
+| 📚 Dataset environments | GSM8K, MMLU | Evaluate and improve LLM performance on static data|
+| 🎮 Online environments | Crosswords, Hangman | Train LLMs through interactive game-based learning |
+| 🤖 RLAIF and RLHF | LLM Judge/Reward Models | Fine-tune LLMs using human feedback and alignment |
+| 🔄 Multi-Turn RL | deepresearch, internal tool calling | Train LLMs on complex multi-step interactions |
+
+</div>
+
 ## 🎉 Upcoming Atropos Hackathon: LLM RL Environments
 
 Join us in San Francisco on May 18th, 2025 for an exciting hackathon focused on building and experimenting with LLM RL Environments! This in-person event will bring together researchers and developers interested in advancing the field of LLM reinforcement learning.
diff --git a/atroposlib/envs/README.md b/atroposlib/envs/README.md
index c85e821c..b7c609fd 100644
--- a/atroposlib/envs/README.md
+++ b/atroposlib/envs/README.md
@@ -27,9 +27,9 @@ These methods **must** be implemented in your subclass:
 * **`async def collect_trajectories(self, item: Item) -> Tuple[Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]], List[Any | None]], List[Item]]`**: The default implementation of this method runs `collect_trajectory` (see below) multiple times in parallel (controlled by `group_size`). You can override this if you have a more efficient way to generate the entire group of responses/trajectories at once based on the input `item` (e.g. the `n` parameter in the OpenAI chat completions API) or some desired coupling of rollouts (e.g. via MCTS). It should return the collected group data and a list of backlog items.
-* **`async def collect_trajectory(self, item: Item) -> Tuple[Any | None, List[Item]]`**: If the rollouts for your environment can be sampled independently, the easiest way to implement GRPO-style grouping is to define the `collect_trajectory` method and use the default implementation of `collect_trajectories` which runs `group_size` instances of `collect_trajectory` in parallel. This method defines the logic for a *single* logical trajectory collection step based on the input `item`.
+* **`async def collect_trajectory(self, item: Item) -> Tuple[Any | ScoredDataItem | None, List[Item]]`**: If the rollouts for your environment can be sampled independently, the easiest way to implement GRPO-style grouping is to define the `collect_trajectory` method and use the default implementation of `collect_trajectories` which runs `group_size` instances of `collect_trajectory` in parallel. This method defines the logic for a *single* logical trajectory collection step based on the input `item`.
   * **Return value**: It returns a tuple containing:\
-    1. The collected data for this step (one trajectory). This data can be processed further in `postprocess_histories`, if you require additional filtering right before sending to the API.\
+    1. The ScoredDataItem for this step (one trajectory). This data can be processed further in `postprocess_histories`, if you require additional filtering right before sending to the API.
     2. A list of new `Item` objects to be added to the backlog for future processing (e.g., follow-up prompts).\
 * **Should I define `collect_trajectory` or override `collect_trajectories`?** If you've got some way to generate your group more efficiently than a bunch of separate but parallel calls to `collect_trajectory`, or if your rollouts aren't independent as in MCTS, you should override `collect_trajectories`. If simplicity and iteration speed is more valuable than efficiency (e.g. at the start of a development cycle) and your rollouts are independent then `collect_trajectory` is for you.
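For illustration, a rough sketch (not part of the diff) of a subclass that defines only `collect_trajectory` and relies on the default `collect_trajectories` for GRPO-style grouping. The environment name, prompt item, and reward rule are invented; other required methods (`setup`, `get_next_item`, `evaluate`, `wandb_log`) are omitted; and the server/tokenizer calls mirror the `gym_taxi` example later in this diff rather than a guaranteed API:

```python
from typing import List, Optional, Tuple

from atroposlib.envs.base import BaseEnv, ScoredDataItem
from atroposlib.type_definitions import Item


class EchoEnv(BaseEnv):
    """Toy single-turn environment; prompt and reward logic are invented for illustration."""

    async def collect_trajectory(
        self, item: Item
    ) -> Tuple[Optional[ScoredDataItem], List[Item]]:
        # One rollout: send the prompt, record the reply.
        async with self.server.dedicated_server() as server:
            messages = [{"role": "user", "content": item["prompt"]}]
            completion = await server.chat_completion(messages=messages, n=1)
            reply = completion.choices[0].message.content
            messages.append({"role": "assistant", "content": reply})

        tokens = self.tokenizer.apply_chat_template(messages)
        masks = tokens  # unmask everything for brevity; real envs usually mask prompt tokens with -100
        score = 1.0 if item["prompt"].lower() in reply.lower() else 0.0

        return (
            ScoredDataItem(tokens=tokens, masks=masks, scores=score, messages=messages),
            [],  # nothing to push onto the backlog
        )
```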
diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py
index 87a640ab..719eb59a 100644
--- a/atroposlib/envs/base.py
+++ b/atroposlib/envs/base.py
@@ -61,6 +61,17 @@ class ScoredDataGroup(TypedDict):
     overrides: Optional[List[Dict]]
 
 
+class ScoredDataItem(TypedDict):
+    tokens: List[int]
+    masks: List[int]
+    scores: float
+    advantages: Optional[List[float]]
+    ref_logprobs: Optional[List[float]]
+    messages: Optional[List[Message]]
+    group_overrides: Optional[Dict]
+    overrides: Optional[Dict]
+
+
 class EvalHandlingEnum(Enum):
     """
     Enum for handling evals.
@@ -229,7 +240,9 @@ class BaseEnv(ABC):
         """
         return cls.env_config_cls(), ServerBaseline()
 
-    async def collect_trajectory(self, item: Item) -> Tuple[Any | None, List[Item]]:
+    async def collect_trajectory(
+        self, item: Item
+    ) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]:
         raise NotImplementedError(
             "Handle env single method must be implemented in subclass "
         )
@@ -249,13 +262,38 @@ class BaseEnv(ABC):
         for _ in range(self.config.group_size):
             tasks.append(self.collect_trajectory(item))
         results = await asyncio.gather(*tasks)
+        if any(not isinstance(result[0], dict) for result in results):
+            logging.error("something wasn't a ScoredDataItem")
+            raise ValueError(
+                "collect_trajectory must return a ScoredDataItem or None to use the default "
+                "collect_trajectories method"
+            )
         backlog = []
-        to_postprocess = []
+        to_postprocess = ScoredDataGroup()
+        to_postprocess["tokens"] = []
+        to_postprocess["masks"] = []
+        to_postprocess["scores"] = []
+        to_postprocess["advantages"] = []
+        to_postprocess["ref_logprobs"] = []
+        to_postprocess["messages"] = []
+        to_postprocess["group_overrides"] = {}
+        to_postprocess["overrides"] = []
+        print("Processing results")
         for result in results:
-            if result[0] is not None:
-                to_postprocess.append(result[0])
+            to_postprocess["tokens"].append(result[0]["tokens"])
+            to_postprocess["masks"].append(result[0]["masks"])
+            to_postprocess["scores"].append(result[0]["scores"])
+            if result[0].get("advantages", None) is not None:
+                to_postprocess["advantages"].append(result[0]["advantages"])
+            if result[0].get("ref_logprobs", None) is not None:
+                to_postprocess["ref_logprobs"].append(result[0]["ref_logprobs"])
+            if result[0].get("messages", None) is not None:
+                to_postprocess["messages"].append(result[0]["messages"])
+            if result[0].get("group_overrides", None) is not None:
+                to_postprocess["group_overrides"].update(result[0]["group_overrides"])
+            if result[0].get("overrides", None) is not None:
+                to_postprocess["overrides"].append(result[0]["overrides"])
             backlog.extend(result[1])
-        random.shuffle(backlog)
 
         return to_postprocess, backlog
 
     async def postprocess_histories(
diff --git a/atroposlib/envs/reward_fns/cosine_scaled_reward.py b/atroposlib/envs/reward_fns/cosine_scaled_reward.py
index 0b620abe..3a34198b 100644
--- a/atroposlib/envs/reward_fns/cosine_scaled_reward.py
+++ b/atroposlib/envs/reward_fns/cosine_scaled_reward.py
@@ -4,7 +4,17 @@ import logging
 from typing import Any, List, Optional, Union
 
 import scipy
-import torch
+
+try:
+    import torch
+except ImportError as e:
+    logger = logging.getLogger(__name__)
+    logger.warning(
+        "torch not installed, please install atroposlib[rewardfns] to use this reward function"
+    )
+    raise e
+
+
 from transformers import AutoModel, AutoTokenizer
 
 from .registry import registry
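To make the `base.py` change above concrete: the new default `collect_trajectories` stacks the per-rollout `ScoredDataItem`s field by field into one `ScoredDataGroup`. A rough sketch of the resulting shapes (not part of the diff; values are invented):

```python
# Two ScoredDataItems returned by collect_trajectory (toy values)...
item_a = {"tokens": [1, 2, 3], "masks": [-100, 2, 3], "scores": 1.0}
item_b = {"tokens": [1, 2, 4], "masks": [-100, 2, 4], "scores": 0.0}

# ...are combined into a single ScoredDataGroup by the default collect_trajectories:
group = {
    "tokens": [[1, 2, 3], [1, 2, 4]],    # one token list per rollout
    "masks": [[-100, 2, 3], [-100, 2, 4]],
    "scores": [1.0, 0.0],                # one score per rollout
    "advantages": [],                    # optional fields stay empty unless provided
    "ref_logprobs": [],
    "messages": [],
    "group_overrides": {},               # merged via dict.update across items
    "overrides": [],
}
```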
diff --git a/atroposlib/tests/test_advantages.py b/atroposlib/tests/test_advantages.py
index 151ebd2b..2643f580 100644
--- a/atroposlib/tests/test_advantages.py
+++ b/atroposlib/tests/test_advantages.py
@@ -1,7 +1,7 @@
 import math
 
+import numpy as np
 import pytest
-import torch
 
 # Adjust the import below if your functions are in a different module.
 from atroposlib.utils.advantages import (
@@ -23,9 +23,9 @@ def test_allclose_to_first_vector():
     """Test that return_vector=True returns a tensor of booleans."""
     values = [1.0, 1.000000001, 1.000000002]
     result = allclose_to_first(values, return_vector=True)
-    assert isinstance(result, torch.Tensor)
+    assert isinstance(result, np.ndarray)
     # All comparisons should be True.
-    assert torch.all(result)
+    assert np.all(result)
 
 
 def test_allclose_to_first_not_close():
@@ -74,15 +74,15 @@ def test_compute_stats_jagged():
 
 def test_compute_discounted_returns():
     """Test compute_discounted_returns with a tensor input."""
-    rewards = torch.tensor([1.0, 1.0, 1.0])
+    rewards = np.array([1.0, 1.0, 1.0])
     gamma = 0.9
     returns = compute_discounted_returns(rewards, gamma)
     # For a 3-element vector:
     # t=2: 1.0
     # t=1: 1.0 + 0.9*1.0 = 1.9
     # t=0: 1.0 + 0.9*1.9 = 2.71
-    expected = torch.tensor([2.71, 1.9, 1.0])
-    assert torch.allclose(returns, expected, rtol=1e-5, atol=1e-8)
+    expected = np.array([2.71, 1.9, 1.0])
+    assert np.allclose(returns, expected, rtol=1e-5, atol=1e-8)
 
 
 def test_compute_discounted_returns_list_input():
@@ -90,8 +90,8 @@ def test_compute_discounted_returns_list_input():
     rewards = [1, 1, 1]
     gamma = 0.0  # With gamma=0, the returns should equal the rewards.
     returns = compute_discounted_returns(rewards, gamma)
-    expected = torch.tensor([1.0, 1.0, 1.0])
-    assert torch.allclose(returns, expected, rtol=1e-5, atol=1e-8)
+    expected = np.array([1.0, 1.0, 1.0])
+    assert np.allclose(returns, expected, rtol=1e-5, atol=1e-8)
 
 
 def test_compute_grpo_process_supervision_advantages_cumsum():
diff --git a/atroposlib/utils/advantages.py b/atroposlib/utils/advantages.py
index dcb31b60..93ec0575 100644
--- a/atroposlib/utils/advantages.py
+++ b/atroposlib/utils/advantages.py
@@ -1,31 +1,32 @@
 from typing import Sequence
 
-import torch
+import numpy as np
 
 from atroposlib.type_definitions import number
 
-TensorLike = torch.Tensor | Sequence[torch.Tensor] | Sequence[Sequence]
+NumpyArrayLike = np.ndarray | Sequence[np.ndarray] | Sequence[Sequence]
 # Type alias for vector of bools
-BoolVector = torch.Tensor
+BoolVector = np.ndarray
 
 
 def allclose_to_first(
-    values: TensorLike,
+    # values: TensorLike,
+    values: NumpyArrayLike,
     rtol: float = 1e-05,
     atol: float = 1e-08,
     equal_nan: bool = False,
     return_vector: bool = False,
 ) -> BoolVector | bool:
     """
-    Check if all tensors in `values` are close to the first tensor `values[0]` using a vectorized approach.
+    Check if all arrays in `values` are close to the first array `values[0]` using a vectorized approach.
 
     If `return_vector` is False (default), returns a single boolean indicating whether
-    every tensor is close to the first tensor. If `return_vector` is True, returns a list
-    of booleans where each element corresponds to whether the respective tensor in
-    `values` is close to the first tensor. The first element is always True.
+    every array is close to the first array. If `return_vector` is True, returns a list
+    of booleans where each element corresponds to whether the respective array in
+    `values` is close to the first array. The first element is always True.
 
     Args:
-        values (torch.Tensor | Sequence[torch.Tensor] | Sequence[Sequence]):
+        values (np.ndarray | Sequence[np.ndarray] | Sequence[Sequence]):
             Nested list of values to compare. Must be rectangular, but not necessarily 2D.
         rtol (float, optional): Relative tolerance. Defaults to 1e-05.
         atol (float, optional): Absolute tolerance. Defaults to 1e-08.
@@ -35,24 +36,22 @@ def allclose_to_first(
 
     Returns:
         bool or BoolVector:
-            - If `return_vector` is False, returns True if all tensors are close to the first tensor;
+            - If `return_vector` is False, returns True if all arrays are close to the first array;
              otherwise, returns False.
-            - If `return_vector` is True, returns a 1D tensor of bools where the first element is True
-              (as the reference tensor is trivially close to itself), and each subsequent element indicates
-              whether the corresponding tensor is close to the first tensor.
+            - If `return_vector` is True, returns a 1D array of bools where the first element is True
+              (as the reference array is trivially close to itself), and each subsequent element indicates
+              whether the corresponding array is close to the first array.
     """
-    if not isinstance(values, torch.Tensor):
-        values = torch.tensor(values)
+    if not isinstance(values, np.ndarray):
+        values = np.array(values)
     reference = values[0]
-    is_close = torch.isclose(
-        values, reference, rtol=rtol, atol=atol, equal_nan=equal_nan
-    )
+    is_close = np.isclose(values, reference, rtol=rtol, atol=atol, equal_nan=equal_nan)
     # flatten dimensions after first
-    result_vector = torch.all(is_close.view(is_close.size(0), -1), dim=1)
+    result_vector = np.all(is_close.reshape(is_close.shape[0], -1), axis=1)
 
-    return result_vector if return_vector else bool(torch.all(result_vector))
+    return result_vector if return_vector else bool(np.all(result_vector))
 
 
 def compute_stats(data: Sequence[number | Sequence]) -> dict[str, float]:
@@ -104,23 +103,23 @@ def compute_stats(data: Sequence[number | Sequence]) -> dict[str, float]:
     return {"mean": mean, "var": variance}
 
 
-def compute_discounted_returns(rewards: torch.Tensor, gamma: float) -> torch.Tensor:
+def compute_discounted_returns(rewards: np.ndarray, gamma: float) -> np.ndarray:
     """Compute discounted returns from a 1D vector of rewards.
 
-    Given a list or torch tensor of rewards and a discount factor, this function computes
+    Given a list or numpy array of rewards and a discount factor, this function computes
     the discounted return at each timestep. The discounted return at time t is defined as:
 
         G_t = rewards[t] + gamma * rewards[t+1] + gamma^2 * rewards[t+2] + ...
 
     Args:
-        rewards (list[float] or torch.Tensor): A 1D list or tensor of rewards.
+        rewards (list[float] or np.ndarray): A 1D list or array of rewards.
         gamma (float): The discount factor (should be between 0 and 1).
 
     Returns:
         list[float]: A list containing the discounted returns for each timestep.
     """
-    if not isinstance(rewards, torch.Tensor):
-        rewards = torch.tensor(rewards, dtype=torch.float)
-    discounted_returns = torch.empty_like(rewards)
+    if not isinstance(rewards, np.ndarray):
+        rewards = np.array(rewards, dtype=np.float32)  # float32 matches the previous torch default
+    discounted_returns = np.empty_like(rewards)
 
     running_return = 0.0
     for t in reversed(range(len(rewards))):
@@ -132,7 +131,7 @@
 
 
 def compute_grpo_process_supervision_advantages(
     rewards: Sequence[Sequence[number]], gamma: float = None, std_tol: float = 1e-8
-) -> list[torch.Tensor]:
+) -> list[np.ndarray]:
     """
     Given a (possibly jagged) list of list of rewards, compute advantages for GRPO.
@@ -144,7 +143,7 @@ def compute_grpo_process_supervision_advantages(
         std_tol (float): The tolerance for the standard deviation.
 
     Returns:
-        A list of tensors of advantages.
+        A list of arrays of advantages.
 
     Raises:
         ValueError: If the standard deviation of the flattened rewards is smaller than the tolerance.
@@ -155,13 +154,11 @@
     if std < std_tol:
         raise ValueError(f"`std` is smaller than tolerance of {std_tol}.")
 
-    normalized_rewards = [
-        (torch.tensor(trajectory) - mean) / std for trajectory in rewards
-    ]
+    normalized_rewards = [(np.array(trajectory) - mean) / std for trajectory in rewards]
 
     if gamma is None:
         advantages = [
-            trajectory.flip(dims=[0]).cumsum(dim=0).flip(dims=[0])
+            np.flip(np.cumsum(np.flip(trajectory, axis=0), axis=0), axis=0)
             for trajectory in normalized_rewards
        ]
     else:
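A quick numerical check of what the ported helpers compute (illustrative values only, not part of the diff): the backward recursion used for discounted returns, and the reversed cumulative sum used for GRPO process-supervision advantages when `gamma` is `None` (in the real function it runs over mean/std-normalized rewards):

```python
import numpy as np

# Discounted returns via the same backward recursion as compute_discounted_returns:
rewards = np.array([1.0, 1.0, 1.0])
gamma = 0.9
returns = np.empty_like(rewards)
running = 0.0
for t in reversed(range(len(rewards))):
    running = rewards[t] + gamma * running
    returns[t] = running
print(returns)  # [2.71, 1.9, 1.0], matching the updated test

# For gamma=None, GRPO process supervision uses a reversed cumulative sum:
traj = np.array([0.5, -0.25, 1.0])
suffix_sums = np.flip(np.cumsum(np.flip(traj, axis=0), axis=0), axis=0)
print(suffix_sums)  # [1.25, 0.75, 1.0]: each entry sums the rewards from t to the end
```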
diff --git a/atroposlib/utils/tokenize_for_trainer.py b/atroposlib/utils/tokenize_for_trainer.py
index c1187fe1..949b217b 100644
--- a/atroposlib/utils/tokenize_for_trainer.py
+++ b/atroposlib/utils/tokenize_for_trainer.py
@@ -1,4 +1,4 @@
-import torch
+import numpy as np
 from transformers import PreTrainedTokenizer
 
 from atroposlib.type_definitions import Message
@@ -39,7 +39,7 @@ def tokenize_for_trainer(
     # (e.g. current date). e.g. consider a system prompt that depends on the current date and a run that crosses
     # midnight from 3/9 to 3/10 under a tokenizer that tokenizes 3/9 and 3/10 with a different number of tokens.
 
-    masks = torch.ones(len(tokens), dtype=torch.long) * -100
+    masks = np.ones(len(tokens), dtype=np.int64) * -100
 
     for i, msg in enumerate(chat):
         if msg["role"] in UNMASKED_ROLES:
@@ -51,7 +51,7 @@ def tokenize_for_trainer(
             )
             start_idx = len(prefix_tokens)
             end_idx = len(unmasked_tokens)
-            masks[start_idx:end_idx] = torch.tensor(unmasked_tokens[start_idx:])
+            masks[start_idx:end_idx] = np.array(unmasked_tokens[start_idx:])
 
     masks = masks.tolist()
     if finish_reason == "length":
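As a side note on what these masks encode (an illustrative sketch, not part of the diff; token ids are invented): positions that should not contribute to the loss are set to -100, while tokens from unmasked roles keep their token ids.

```python
import numpy as np

# Suppose the full chat tokenizes to 6 tokens: 4 prompt tokens followed by 2 assistant tokens.
tokens = [101, 7592, 2088, 102, 3000, 102]            # invented ids
masks = np.ones(len(tokens), dtype=np.int64) * -100   # start fully masked, as in the diff
masks[4:] = np.array(tokens[4:])                      # unmask only the assistant span
print(masks.tolist())  # [-100, -100, -100, -100, 3000, 102]
```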
diff --git a/environments/gym_taxi.py b/environments/gym_taxi.py
new file mode 100644
index 00000000..d4f239a1
--- /dev/null
+++ b/environments/gym_taxi.py
@@ -0,0 +1,331 @@
+from typing import Dict, List, Optional, Tuple
+
+import gymnasium as gym
+
+from atroposlib.envs.base import BaseEnv, BaseEnvConfig, OpenaiConfig, ScoredDataItem
+from atroposlib.type_definitions import Item
+
+start_msg = """### Description
+There are four designated locations in the grid world indicated by R(ed),
+G(reen), Y(ellow), and B(lue). When the episode starts, the taxi starts off
+at a random square and the passenger is at a random location. The taxi
+drives to the passenger's location, picks up the passenger, drives to the
+passenger's destination (another one of the four specified locations), and
+then drops off the passenger. Once the passenger is dropped off, the episode ends.
+
+Map:
+
+ +---------+
+ |R: | : :G|
+ | : | : : |
+ | : : : : |
+ | | : | : |
+ |Y| : |B: |
+ +---------+
+
+### Actions
+There are 6 discrete deterministic actions:
+- 0: move south (increases row index)
+- 1: move north (decreases row index)
+- 2: move east (increases column index)
+- 3: move west (decreases column index)
+- 4: pickup passenger (IF on a letter location, AND passenger is located at the same location, pickup passenger)
+- 5: drop off passenger
+
+### Observations
+
+Passenger locations:
+- 0: R(ed)
+- 1: G(reen)
+- 2: Y(ellow)
+- 3: B(lue)
+- 4: in taxi
+
+Destinations:
+- 0: R(ed) (Row 0, Col 0)
+- 1: G(reen) (Row 4, Col 4)
+- 2: Y(ellow) (Row 0, Col 4)
+- 3: B(lue) (Row 3, Col 3)
+
+### Instructions
+Please perform the actions that will let you pick up and/or drop off the passenger.
+Please respond with the action number only.
+You cannot move the taxi into walls, which are displayed as | in the map. : means you are free to move through that column.
+
+
+For an example, if the passenger is at R, and the destination is G, and the taxi is at (2, 2), then here are the following actions to solve this in the correct order:
+
+3 (move west)
+3 (move west)
+1 (move north)
+1 (move north)
+4 (pickup passenger)
+0 (move south)
+0 (move south)
+2 (move east)
+2 (move east)
+2 (move east)
+2 (move east)
+0 (move south)
+0 (move south)
+5 (drop off passenger)
+
+If you are stuck, try moving to row idx 2, as there are no walls there.
+
+Submit your response as a number between 0 and 5 only to perform the discrete action.
+Each turn we will give you the current state of the environment, and you will need to respond with the action number only from the available actions."""  # noqa: E501
+
+
+def decode(i):
+    out = []
+    out.append(i % 4)
+    i = i // 4
+    out.append(i % 5)
+    i = i // 5
+    out.append(i % 5)
+    i = i // 5
+    out.append(i)
+    assert 0 <= i < 5
+    x = reversed(out)
+    # Making it explicit so I don't have to look into gym code
+    taxi_row, taxi_col, pass_idx, dest_idx = x
+    return taxi_row, taxi_col, pass_idx, dest_idx
+
+
+# Note: Works for both the passenger and the destination
+TO_LOC_MAP = {
+    0: "R(Row 0, Col 0)",
+    1: "G (Row 4, Col 4)",
+    2: "Y (Row 0, Col 4)",
+    3: "B (Row 3, Col 3)",
+    4: "in taxi",
+}
+MAP_LOC = {0: (0, 0), 1: (4, 4), 2: (0, 4), 3: (3, 3)}
+TO_ACTION_MAP = {
+    0: "south",
+    1: "north",
+    2: "east",
+    3: "west",
+    4: "pickup",
+    5: "dropoff",
+}
+
+
+def state_render_to_user_msg(last_state, state, action_mask, render):
+    taxi_row, taxi_col, pass_idx, dest_idx = decode(state)
+    if last_state is not None:
+        last_taxi_row, last_taxi_col, last_pass_idx, last_dest_idx = decode(last_state)
+    available_actions = "\n".join(
+        [
+            f"- {i}: {TO_ACTION_MAP[i]}"
+            for i in range(6)
+            if (action_mask[i] == 1)
+            and (
+                (i != 5)
+                or (
+                    (i == 5)
+                    and (taxi_row == MAP_LOC[dest_idx][0])
+                    and (taxi_col == MAP_LOC[dest_idx][1])
+                )
+            )
+        ]
+    )
+    if last_state is not None:
+        ret_str = (
+            f"Previous Taxi Location: Row: {last_taxi_row}, Col: {last_taxi_col}\n"
+        )
+    else:
+        ret_str = ""
+    ret_str += (
+        f"Current state:\nTaxi: Row: {taxi_row}, Col: {taxi_col}\nPassenger: {TO_LOC_MAP[pass_idx]}\n"
+        f"Destination: {TO_LOC_MAP[dest_idx]}\n\n"
+        f"Map:\n{render}\n\n"
+        f"Available actions:\n{available_actions}"
+    )
+    if (
+        (pass_idx == 4)
+        and (taxi_row == MAP_LOC[dest_idx][0])
+        and (taxi_col == MAP_LOC[dest_idx][1])
+    ):
+        ret_str += "\n\nPlease drop off the passenger."
+    elif pass_idx == 4:
+        ret_str += f"\n\nPlease move the taxi to {TO_LOC_MAP[dest_idx]} to drop off the passenger."
+    elif (taxi_row == MAP_LOC[pass_idx][0]) and (taxi_col == MAP_LOC[pass_idx][1]):
+        ret_str += "\n\nPlease pick up the passenger."
+    else:
+        ret_str += f"\n\nPlease move the taxi to {TO_LOC_MAP[pass_idx]} to pick up the passenger."
+    return ret_str
+
+
+class GymTaxiEnv(BaseEnv):
+
+    name = "gym_taxi"
+
+    def __init__(
+        self,
+        config: BaseEnvConfig,
+        server_configs: List[OpenaiConfig],
+        slurm=True,
+        testing=False,
+    ):
+        super().__init__(config, server_configs, slurm, testing)
+        self.percent_correct_buffer = list()
+        self.percent_picked_up_passenger_buffer = list()
+        self.eval_metrics = list()
+        # Add tracking for wandb visualizations
+        self.rollouts_for_wandb = []
+        self.completion_lengths = []
+        self.print_this_env = False
+
+    @classmethod
+    def config_init(cls) -> Tuple[BaseEnvConfig, List[OpenaiConfig]]:
+        env_config = BaseEnvConfig(
+            tokenizer_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
+            group_size=32,
+            use_wandb=True,
+            rollout_server_url="http://localhost:8000",
+            max_token_length=8192,
+            wandb_name="gym_taxi",
+        )
+        server_configs = [
+            OpenaiConfig(
+                model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
+                base_url="http://localhost:9001/v1",
+                api_key="x",
+                num_requests_for_eval=256,
+            ),
+        ]
+
+        return env_config, server_configs
+
+    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
+        if wandb_metrics is None:
+            wandb_metrics = {}
+
+        # Try to calculate percent_correct, pass if there's a division by zero
+        try:
+            wandb_metrics["train/percent_correct"] = sum(
+                self.percent_correct_buffer
+            ) / len(self.percent_correct_buffer)
+        except ZeroDivisionError:
+            # Skip if buffer is empty
+            pass
+        try:
+            wandb_metrics["train/percent_picked_up_passenger"] = sum(
+                self.percent_picked_up_passenger_buffer
+            ) / len(self.percent_picked_up_passenger_buffer)
+        except ZeroDivisionError:
+            # Skip if buffer is empty
+            pass
+
+        self.percent_correct_buffer = list()
+        self.percent_picked_up_passenger_buffer = list()
+        for item in self.eval_metrics:
+            wandb_metrics[item[0]] = item[1]
+        self.eval_metrics = list()
+        # Call the parent method to handle the server metrics
+        await super().wandb_log(wandb_metrics)
+
+    async def setup(self):
+        self.iter = 0
+
+    async def evaluate(self, *args, **kwargs):
+        pass
+
+    async def collect_trajectory(
+        self, item: Item
+    ) -> Tuple[Optional[ScoredDataItem], List[Item]]:
+        # Grab a dedicated llm server to take advantage of caching
+        async with self.server.dedicated_server() as server:
+            env = gym.make("Taxi-v3", render_mode="ansi")
+            state, info = env.reset(seed=item["seed"])
+            last_state = None
+            taxi_row, taxi_col, pass_idx, dest_idx = decode(state)
+            init_msg = f"{start_msg}\n\n" + state_render_to_user_msg(
+                last_state, state, info["action_mask"], env.render()
+            )
+            messages = [{"role": "user", "content": init_msg}]
+            score = -1
+            while True:
+                if (
+                    len(self.tokenizer.apply_chat_template(messages))
+                    > self.config.max_token_length - 10
+                ):
+                    break
+                max_tokens = self.config.max_token_length - len(
+                    self.tokenizer.apply_chat_template(
+                        messages, add_generation_prompt=True
+                    )
+                )
+                chat_completions = await server.chat_completion(
+                    messages=messages,
+                    n=1,
+                    max_tokens=max_tokens,
+                )
+                choice = (
+                    chat_completions.choices[0]
+                    .message.content.strip()
+                    .replace(".", "")[-1]
+                )
+                messages.append(
+                    {
+                        "role": "assistant",
+                        "content": chat_completions.choices[0].message.content,
+                    }
+                )
+                if choice.isdigit() and 0 <= int(choice) <= 5:
+                    action = int(choice)
+                else:
+                    break
+                if info["action_mask"][action] == 0:
+                    break
+                if action == 4:
+                    # picked up passenger
+                    score = 0
+                next_state, reward, terminated, truncated, info = env.step(action)
+                last_state = state
+                state = next_state
+                if terminated:
+                    score = 1
+                    break
+                messages.append(
+                    {
"role": "user", + "content": state_render_to_user_msg( + last_state, state, info["action_mask"], env.render() + ), + } + ) + self.percent_correct_buffer.append(max(score, 0)) + self.percent_picked_up_passenger_buffer.append(1 if score >= 0 else 0) + tokens = self.tokenizer.apply_chat_template(messages) + masks = [] + for i, msg in enumerate(messages): + if i == len(messages) - 1: + masks.extend(tokens[len(masks) :]) + else: + curr_tokens = self.tokenizer.apply_chat_template( + messages[: i + 1], + add_generation_prompt=messages[i + 1]["role"] == "assistant", + ) + if messages[i]["role"] == "user": + masks.extend([-100] * (len(curr_tokens) - len(masks))) + else: + masks.extend(curr_tokens[len(masks) :]) + scored_data_item = ScoredDataItem( + messages=messages, + finish_reason=score, + tokens=tokens, + masks=masks, + scores=score, + ) + return scored_data_item, [] + + async def get_next_item(self): + next_item = {"seed": self.iter} + self.iter += 1 + return next_item + + +if __name__ == "__main__": + GymTaxiEnv.cli() diff --git a/pyproject.toml b/pyproject.toml index ddd8ffbe..1547aee7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ dependencies = [ "markdown", "numpy", "wandb", + "gymnasium", "math-verify==0.7.0", "jinja2", "nltk", @@ -23,7 +24,6 @@ dependencies = [ "polars", "aiofiles", "jsonlines", - "torch", "pydantic-cli", "hf_transfer", ] @@ -40,6 +40,9 @@ atropos-dpo-gen = "atroposlib.cli.dpo:main" all = [ "atroposlib[dev,examples]" ] +rewardfns = [ + "torch" +] dev = [ "pytest", "pytest-asyncio", @@ -48,10 +51,11 @@ dev = [ "flake8", "isort", "mypy", - 'rich', + "rich", ] examples = [ - "gradio" + "gradio", + "atroposlib[rewardfns]" ] [build-system]