Mirror of https://github.com/NousResearch/atropos.git (synced 2026-04-28 17:29:30 +00:00)

Merge branch 'main' into blackjack2-env

commit 36f6822d71
9 changed files with 455 additions and 61 deletions
@@ -27,9 +27,9 @@ These methods **must** be implemented in your subclass:
 
 * **`async def collect_trajectories(self, item: Item) -> Tuple[Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]], List[Any | None]], List[Item]]`**: The default implementation of this method runs `collect_trajectory` (see below) multiple times in parallel (controlled by `group_size`). You can override this if you have a more efficient way to generate the entire group of responses/trajectories at once based on the input `item` (e.g. the `n` parameter in the OpenAI chat completions API) or some desired coupling of rollouts (e.g. via MCTS). It should return the collected group data and a list of backlog items.
 
-* **`async def collect_trajectory(self, item: Item) -> Tuple[Any | None, List[Item]]`**: If the rollouts for your environment can be sampled independently, the easiest way to implement GRPO-style grouping is to define the `collect_trajectory` method and use the default implementation of `collect_trajectories`, which runs `group_size` instances of `collect_trajectory` in parallel. This method defines the logic for a *single* logical trajectory collection step based on the input `item`.
+* **`async def collect_trajectory(self, item: Item) -> Tuple[Any | ScoredDataItem | None, List[Item]]`**: If the rollouts for your environment can be sampled independently, the easiest way to implement GRPO-style grouping is to define the `collect_trajectory` method and use the default implementation of `collect_trajectories`, which runs `group_size` instances of `collect_trajectory` in parallel. This method defines the logic for a *single* logical trajectory collection step based on the input `item`.
   * **Return value**: It returns a tuple containing:\
-    1. The collected data for this step (one trajectory). This data can be processed further in `postprocess_histories`, if you require additional filtering right before sending to the API.\
+    1. The ScoredDataItem for this step (one trajectory). This data can be processed further in `postprocess_histories`, if you require additional filtering right before sending to the API.
     2. A list of new `Item` objects to be added to the backlog for future processing (e.g., follow-up prompts).\
 
 * **Should I define `collect_trajectory` or override `collect_trajectories`?** If you have a way to generate your group more efficiently than separate parallel calls to `collect_trajectory`, or if your rollouts aren't independent (as in MCTS), you should override `collect_trajectories`. If simplicity and iteration speed are more valuable than efficiency (e.g. at the start of a development cycle) and your rollouts are independent, then `collect_trajectory` is for you.
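Editor's note: to make the contract above concrete, here is a minimal, hypothetical sketch of a subclass implementing `collect_trajectory` under the new signature. The environment name, rollout logic, and values are placeholders; only the return shape (a `ScoredDataItem` plus a backlog list) follows the documented interface, and the import path for `BaseEnv`, `Item`, and `ScoredDataItem` is an assumption.

```python
# Hypothetical minimal environment; rollout logic and values are placeholders.
from typing import Any, List, Optional, Tuple, Union

from atroposlib.envs.base import BaseEnv, Item, ScoredDataItem  # assumed import path


class EchoEnv(BaseEnv):
    async def collect_trajectory(
        self, item: Item
    ) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]:
        # ...run one rollout for `item` here (e.g. a single chat completion)...
        scored = ScoredDataItem(
            tokens=[1, 2, 3],    # placeholder token ids
            masks=[-100, 2, 3],  # -100 marks positions excluded from the loss
            scores=1.0,          # scalar score for this single trajectory
            advantages=None,
            ref_logprobs=None,
            messages=None,
            group_overrides=None,
            overrides=None,
        )
        backlog: List[Item] = []  # e.g. follow-up prompts to revisit later
        return scored, backlog
```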
@@ -61,6 +61,17 @@ class ScoredDataGroup(TypedDict):
     overrides: Optional[List[Dict]]
 
 
+class ScoredDataItem(TypedDict):
+    tokens: List[int]
+    masks: List[int]
+    scores: float
+    advantages: Optional[List[float]]
+    ref_logprobs: Optional[List[float]]
+    messages: Optional[List[Message]]
+    group_overrides: Optional[Dict]
+    overrides: Optional[Dict]
+
+
 class EvalHandlingEnum(Enum):
     """
     Enum for handling evals.
@@ -229,7 +240,9 @@ class BaseEnv(ABC):
         """
         return cls.env_config_cls(), ServerBaseline()
 
-    async def collect_trajectory(self, item: Item) -> Tuple[Any | None, List[Item]]:
+    async def collect_trajectory(
+        self, item: Item
+    ) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]:
         raise NotImplementedError(
             "Handle env single method must be implemented in subclass "
         )
@@ -249,13 +262,38 @@ class BaseEnv(ABC):
         for _ in range(self.config.group_size):
             tasks.append(self.collect_trajectory(item))
         results = await asyncio.gather(*tasks)
+        if any(not isinstance(result[0], dict) for result in results):
+            logging.error("something wasn't a ScoredDataItem")
+            raise ValueError(
+                "collect_trajectory must return a ScoredDataItem or None to use the default "
+                "collect_trajectories method"
+            )
         backlog = []
-        to_postprocess = []
+        to_postprocess = ScoredDataGroup()
+        to_postprocess["tokens"] = []
+        to_postprocess["masks"] = []
+        to_postprocess["scores"] = []
+        to_postprocess["advantages"] = []
+        to_postprocess["ref_logprobs"] = []
+        to_postprocess["messages"] = []
+        to_postprocess["group_overrides"] = {}
+        to_postprocess["overrides"] = []
+        print("Processing results")
         for result in results:
             if result[0] is not None:
-                to_postprocess.append(result[0])
+                to_postprocess["tokens"].append(result[0]["tokens"])
+                to_postprocess["masks"].append(result[0]["masks"])
+                to_postprocess["scores"].append(result[0]["scores"])
+                if result[0].get("advantages", None) is not None:
+                    to_postprocess["advantages"].append(result[0]["advantages"])
+                if result[0].get("ref_logprobs", None) is not None:
+                    to_postprocess["ref_logprobs"].append(result[0]["ref_logprobs"])
+                if result[0].get("messages", None) is not None:
+                    to_postprocess["messages"].append(result[0]["messages"])
+                if result[0].get("group_overrides", None) is not None:
+                    to_postprocess["group_overrides"].update(result[0]["group_overrides"])
+                if result[0].get("overrides", None) is not None:
+                    to_postprocess["overrides"].append(result[0]["overrides"])
             backlog.extend(result[1])
         random.shuffle(backlog)
         return to_postprocess, backlog
 
     async def postprocess_histories(
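Editor's note: the default `collect_trajectories` above is doing column-wise aggregation — `group_size` per-rollout `ScoredDataItem`s become one `ScoredDataGroup` of parallel lists. A self-contained sketch of the same reshaping, with plain dicts standing in for the TypedDicts:

```python
# Plain-dict illustration of the column-wise grouping done above.
items = [
    {"tokens": [1, 2], "masks": [-100, 2], "scores": 1.0},
    {"tokens": [3, 4], "masks": [-100, 4], "scores": 0.0},
]

group = {"tokens": [], "masks": [], "scores": []}
for item in items:
    group["tokens"].append(item["tokens"])
    group["masks"].append(item["masks"])
    group["scores"].append(item["scores"])

assert group["scores"] == [1.0, 0.0]  # one entry per rollout, order preserved
```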
@@ -4,7 +4,17 @@ import logging
 from typing import Any, List, Optional, Union
 
 import scipy
-import torch
+
+try:
+    import torch
+except ImportError as e:
+    logger = logging.getLogger(__name__)
+    logger.warning(
+        "torch not installed, please install atroposlib[rewardfns] to use this reward function"
+    )
+    raise e
+
 
 from transformers import AutoModel, AutoTokenizer
 
 from .registry import registry
@@ -1,7 +1,7 @@
 import math
 
+import numpy as np
 import pytest
-import torch
 
 # Adjust the import below if your functions are in a different module.
 from atroposlib.utils.advantages import (
@@ -23,9 +23,9 @@ def test_allclose_to_first_vector():
     """Test that return_vector=True returns a tensor of booleans."""
     values = [1.0, 1.000000001, 1.000000002]
     result = allclose_to_first(values, return_vector=True)
-    assert isinstance(result, torch.Tensor)
+    assert isinstance(result, np.ndarray)
     # All comparisons should be True.
-    assert torch.all(result)
+    assert np.all(result)
 
 
 def test_allclose_to_first_not_close():
@@ -74,15 +74,15 @@ def test_compute_stats_jagged():
 
 def test_compute_discounted_returns():
     """Test compute_discounted_returns with a tensor input."""
-    rewards = torch.tensor([1.0, 1.0, 1.0])
+    rewards = np.array([1.0, 1.0, 1.0])
     gamma = 0.9
     returns = compute_discounted_returns(rewards, gamma)
     # For a 3-element vector:
     # t=2: 1.0
     # t=1: 1.0 + 0.9*1.0 = 1.9
     # t=0: 1.0 + 0.9*1.9 = 2.71
-    expected = torch.tensor([2.71, 1.9, 1.0])
-    assert torch.allclose(returns, expected, rtol=1e-5, atol=1e-8)
+    expected = np.array([2.71, 1.9, 1.0])
+    assert np.allclose(returns, expected, rtol=1e-5, atol=1e-8)
 
 
 def test_compute_discounted_returns_list_input():
@@ -90,8 +90,8 @@ def test_compute_discounted_returns_list_input():
     rewards = [1, 1, 1]
     gamma = 0.0  # With gamma=0, the returns should equal the rewards.
     returns = compute_discounted_returns(rewards, gamma)
-    expected = torch.tensor([1.0, 1.0, 1.0])
-    assert torch.allclose(returns, expected, rtol=1e-5, atol=1e-8)
+    expected = np.array([1.0, 1.0, 1.0])
+    assert np.allclose(returns, expected, rtol=1e-5, atol=1e-8)
 
 
 def test_compute_grpo_process_supervision_advantages_cumsum():
@@ -1,31 +1,32 @@
 from typing import Sequence
 
-import torch
+import numpy as np
 
 from atroposlib.type_definitions import number
 
-TensorLike = torch.Tensor | Sequence[torch.Tensor] | Sequence[Sequence]
+NumpyArrayLike = np.ndarray | Sequence[np.ndarray] | Sequence[Sequence]
 # Type alias for vector of bools
-BoolVector = torch.Tensor
+BoolVector = np.ndarray
 
 
 def allclose_to_first(
-    values: TensorLike,
+    # values: TensorLike,
+    values: NumpyArrayLike,
     rtol: float = 1e-05,
     atol: float = 1e-08,
     equal_nan: bool = False,
     return_vector: bool = False,
 ) -> BoolVector | bool:
     """
-    Check if all tensors in `values` are close to the first tensor `values[0]` using a vectorized approach.
+    Check if all arrays in `values` are close to the first array `values[0]` using a vectorized approach.
 
     If `return_vector` is False (default), returns a single boolean indicating whether
-    every tensor is close to the first tensor. If `return_vector` is True, returns a list
-    of booleans where each element corresponds to whether the respective tensor in
-    `values` is close to the first tensor. The first element is always True.
+    every array is close to the first array. If `return_vector` is True, returns a list
+    of booleans where each element corresponds to whether the respective array in
+    `values` is close to the first array. The first element is always True.
 
     Args:
-        values (torch.Tensor | Sequence[torch.Tensor] | Sequence[Sequence]):
+        values (np.ndarray | Sequence[np.ndarray] | Sequence[Sequence]):
             Nested list of values to compare. Must be rectangular, but not necessarily 2D.
         rtol (float, optional): Relative tolerance. Defaults to 1e-05.
         atol (float, optional): Absolute tolerance. Defaults to 1e-08.
@@ -35,24 +36,22 @@ def allclose_to_first(
 
     Returns:
         bool or BoolVector:
-            - If `return_vector` is False, returns True if all tensors are close to the first tensor;
+            - If `return_vector` is False, returns True if all arrays are close to the first array;
               otherwise, returns False.
-            - If `return_vector` is True, returns a 1D tensor of bools where the first element is True
-              (as the reference tensor is trivially close to itself), and each subsequent element indicates
-              whether the corresponding tensor is close to the first tensor.
+            - If `return_vector` is True, returns a 1D array of bools where the first element is True
+              (as the reference array is trivially close to itself), and each subsequent element indicates
+              whether the corresponding array is close to the first array.
     """
-    if not isinstance(values, torch.Tensor):
-        values = torch.tensor(values)
+    if not isinstance(values, np.ndarray):
+        values = np.array(values)
 
     reference = values[0]
-    is_close = torch.isclose(
-        values, reference, rtol=rtol, atol=atol, equal_nan=equal_nan
-    )
+    is_close = np.isclose(values, reference, rtol=rtol, atol=atol, equal_nan=equal_nan)
 
     # flatten dimensions after first
-    result_vector = torch.all(is_close.view(is_close.size(0), -1), dim=1)
+    result_vector = np.all(is_close.reshape(is_close.shape[0], -1), axis=1)
 
-    return result_vector if return_vector else bool(torch.all(result_vector))
+    return result_vector if return_vector else bool(np.all(result_vector))
 
 
 def compute_stats(data: Sequence[number | Sequence]) -> dict[str, float]:
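Editor's note: a quick usage sketch for the numpy version of `allclose_to_first` above (the import path is taken from the test file earlier in this diff; the values are illustrative):

```python
import numpy as np

from atroposlib.utils.advantages import allclose_to_first

# Second row matches the first within default tolerances; third clearly differs.
values = [[1.0, 2.0], [1.0, 2.0 + 1e-9], [5.0, 6.0]]

print(allclose_to_first(values))                      # False: one row differs
print(allclose_to_first(values, return_vector=True))  # [ True  True False]
```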
@@ -104,23 +103,23 @@ def compute_stats(data: Sequence[number | Sequence]) -> dict[str, float]:
     return {"mean": mean, "var": variance}
 
 
-def compute_discounted_returns(rewards: torch.Tensor, gamma: float) -> torch.Tensor:
+def compute_discounted_returns(rewards: np.ndarray, gamma: float) -> np.ndarray:
     """Compute discounted returns from a 1D vector of rewards.
 
-    Given a list or torch tensor of rewards and a discount factor, this function computes
+    Given a list or numpy array of rewards and a discount factor, this function computes
     the discounted return at each timestep. The discounted return at time t is defined as:
         G_t = rewards[t] + gamma * rewards[t+1] + gamma^2 * rewards[t+2] + ...
 
     Args:
-        rewards (list[float] or torch.Tensor): A 1D list or tensor of rewards.
+        rewards (list[float] or np.ndarray): A 1D list or array of rewards.
         gamma (float): The discount factor (should be between 0 and 1).
 
     Returns:
         list[float]: A list containing the discounted returns for each timestep.
     """
-    if not isinstance(rewards, torch.Tensor):
-        rewards = torch.tensor(rewards, dtype=torch.float)
-    discounted_returns = torch.empty_like(rewards)
+    if not isinstance(rewards, np.ndarray):
+        rewards = np.array(rewards, dtype=np.float32)  # Use float32 for numpy default
+    discounted_returns = np.empty_like(rewards)
     running_return = 0.0
 
     for t in reversed(range(len(rewards))):
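Editor's note: the loop body is truncated in this hunk, but the recursion it implements is the standard backward pass G_t = r_t + gamma * G_{t+1}. A hedged usage sketch, worked numerically to match the updated test above:

```python
import numpy as np

from atroposlib.utils.advantages import compute_discounted_returns  # path per the tests

# Backward recursion G_t = r_t + gamma * G_{t+1}:
# t=2: 1.0;  t=1: 1.0 + 0.9 * 1.0 = 1.9;  t=0: 1.0 + 0.9 * 1.9 = 2.71
returns = compute_discounted_returns(np.array([1.0, 1.0, 1.0]), gamma=0.9)
assert np.allclose(returns, [2.71, 1.9, 1.0])
```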
@@ -132,7 +131,7 @@ def compute_discounted_returns(rewards: torch.Tensor, gamma: float) -> torch.Ten
 
 def compute_grpo_process_supervision_advantages(
     rewards: Sequence[Sequence[number]], gamma: float = None, std_tol: float = 1e-8
-) -> list[torch.Tensor]:
+) -> list[np.ndarray]:
     """
     Given a (possibly jagged) list of list of rewards, compute advantages for GRPO.
 

@@ -144,7 +143,7 @@ def compute_grpo_process_supervision_advantages(
         std_tol (float): The tolerance for the standard deviation.
 
     Returns:
-        A list of tensors of advantages.
+        A list of arrays of advantages.
 
     Raises:
         ValueError: If the standard deviation of the flattened rewards is smaller than the tolerance.

@@ -155,13 +154,11 @@ def compute_grpo_process_supervision_advantages(
     if std < std_tol:
         raise ValueError(f"`std` is smaller than tolerance of {std_tol}.")
 
-    normalized_rewards = [
-        (torch.tensor(trajectory) - mean) / std for trajectory in rewards
-    ]
+    normalized_rewards = [(np.array(trajectory) - mean) / std for trajectory in rewards]
 
     if gamma is None:
         advantages = [
-            trajectory.flip(dims=[0]).cumsum(dim=0).flip(dims=[0])
+            np.flip(np.cumsum(np.flip(trajectory, axis=0), axis=0), axis=0)
             for trajectory in normalized_rewards
         ]
     else:
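Editor's note: to see what the `gamma is None` branch computes — after normalization, each step's advantage is the sum of its own and all later normalized rewards, i.e. a reversed cumulative sum. A small standalone check (values are illustrative):

```python
import numpy as np

# Reversed cumulative sum: [a, b, c] -> [a+b+c, b+c, c].
traj = np.array([0.5, -0.25, 1.0])
adv = np.flip(np.cumsum(np.flip(traj, axis=0), axis=0), axis=0)
print(adv)  # [1.25 0.75 1.  ]
```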
@@ -1,4 +1,4 @@
-import torch
+import numpy as np
 from transformers import PreTrainedTokenizer
 
 from atroposlib.type_definitions import Message
@@ -39,7 +39,7 @@ def tokenize_for_trainer(
     # (e.g. current date). e.g. consider a system prompt that depends on the current date and a run that crosses
     # midnight from 3/9 to 3/10 under a tokenizer that tokenizes 3/9 and 3/10 with a different number of tokens.
 
-    masks = torch.ones(len(tokens), dtype=torch.long) * -100
+    masks = np.ones(len(tokens), dtype=np.int64) * -100
 
     for i, msg in enumerate(chat):
         if msg["role"] in UNMASKED_ROLES:
@@ -51,7 +51,7 @@ def tokenize_for_trainer(
         )
         start_idx = len(prefix_tokens)
         end_idx = len(unmasked_tokens)
-        masks[start_idx:end_idx] = torch.tensor(unmasked_tokens[start_idx:])
+        masks[start_idx:end_idx] = np.array(unmasked_tokens[start_idx:])
 
     masks = masks.tolist()
     if finish_reason == "length":
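Editor's note: to illustrate the masking convention these two hunks preserve (a hedged sketch, not the real `tokenize_for_trainer` flow): every position starts at -100, which the trainer's loss ignores, and positions belonging to unmasked roles are overwritten with their own token ids.

```python
import numpy as np

tokens = [101, 7592, 2088, 102]  # placeholder token ids
masks = np.ones(len(tokens), dtype=np.int64) * -100  # all positions masked by default

# Suppose the last two positions belong to an assistant turn we train on.
start_idx, end_idx = 2, 4
masks[start_idx:end_idx] = np.array(tokens[start_idx:end_idx])

print(masks.tolist())  # [-100, -100, 2088, 102]
```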