Merge branch 'main' into blackjack2-env

This commit is contained in:
Shannon Sands 2025-05-13 07:54:04 +10:00
commit 36f6822d71
9 changed files with 455 additions and 61 deletions

View file

@ -27,9 +27,9 @@ These methods **must** be implemented in your subclass:
* **`async def collect_trajectories(self, item: Item) -> Tuple[Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]], List[Any | None]], List[Item]]`**: The default implementation of this method runs `collect_trajectory` (see below) multiple times in parallel (controlled by `group_size`). You can override this if you have a more efficient way to generate the entire group of responses/trajectories at once based on the input `item` (e.g. the `n` parameter in the OpenAI chat completions API) or some desired coupling of rollouts (e.g. via MCTS). It should return the collected group data and a list of backlog items.
* **`async def collect_trajectory(self, item: Item) -> Tuple[Any | ScoredDataItem | None, List[Item]]`**: If the rollouts for your environment can be sampled independently, the easiest way to implement GRPO-style grouping is to define the `collect_trajectory` method and use the default implementation of `collect_trajectories` which runs `group_size` instances of `collect_trajectory` in parallel. This method defines the logic for a *single* logical trajectory collection step based on the input `item`.
* **Return value**: It returns a tuple containing:\
1. The ScoredDataItem for this step (one trajectory). This data can be processed further in `postprocess_histories`, if you require additional filtering right before sending to the API.
2. A list of new `Item` objects to be added to the backlog for future processing (e.g., follow-up prompts).\
* **Should I define `collect_trajectory` or override `collect_trajectories`?** If you've got some way to generate your group more efficiently than a bunch of separate but parallel calls to `collect_trajectory`, or if your rollouts aren't independent as in MCTS, you should override `collect_trajectories`. If simplicity and iteration speed is more valuable than efficiency (e.g. at the start of a development cycle) and your rollouts are independent then `collect_trajectory` is for you.

View file

@ -61,6 +61,17 @@ class ScoredDataGroup(TypedDict):
overrides: Optional[List[Dict]]
class ScoredDataItem(TypedDict):
    """A single scored trajectory.

    Per-item analogue of ScoredDataGroup: holds one scalar score and flat
    token/mask lists, where the group type holds per-group lists of these.
    """

    # Token ids for the trajectory's full sequence.
    tokens: List[int]
    # Loss mask aligned with `tokens` — presumably same length; confirm
    # against tokenize_for_trainer.
    masks: List[int]
    # Scalar reward/score for this single trajectory.
    scores: float
    # Optional advantages for this trajectory.
    advantages: Optional[List[float]]
    # Optional reference-model log-probabilities.
    ref_logprobs: Optional[List[float]]
    # Optional chat messages the trajectory was built from.
    messages: Optional[List[Message]]
    # Overrides merged into the enclosing group's `group_overrides` dict.
    group_overrides: Optional[Dict]
    # Overrides applied to this item only.
    overrides: Optional[Dict]
class EvalHandlingEnum(Enum):
"""
Enum for handling evals.
@ -229,7 +240,9 @@ class BaseEnv(ABC):
"""
return cls.env_config_cls(), ServerBaseline()
async def collect_trajectory(
    self, item: "Item"
) -> "Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]":
    """Collect a single logical trajectory for ``item``.

    The base class provides no implementation: subclasses must override
    this method (or override ``collect_trajectories`` directly).

    Args:
        item: The work item to roll out one trajectory for.

    Returns:
        A tuple of (ScoredDataItem or other trajectory data or None,
        list of follow-up Items for the backlog).

    Raises:
        NotImplementedError: Always, in this base class.
    """
    # NOTE: the stale pre-merge one-line signature that duplicated this
    # definition has been removed; annotations are string forward refs.
    raise NotImplementedError(
        "Handle env single method must be implemented in subclass "
    )
@ -249,13 +262,38 @@ class BaseEnv(ABC):
for _ in range(self.config.group_size):
tasks.append(self.collect_trajectory(item))
results = await asyncio.gather(*tasks)
if any(not isinstance(result[0], dict) for result in results):
logging.error("something wasn't a ScoredDataItem")
raise ValueError(
"collect_trajectory must return a ScoredDataItem or None to use the default "
"collect_trajectories method"
)
backlog = []
to_postprocess = []
to_postprocess = ScoredDataGroup()
to_postprocess["tokens"] = []
to_postprocess["masks"] = []
to_postprocess["scores"] = []
to_postprocess["advantages"] = []
to_postprocess["ref_logprobs"] = []
to_postprocess["messages"] = []
to_postprocess["group_overrides"] = {}
to_postprocess["overrides"] = []
print("Processing results")
for result in results:
if result[0] is not None:
to_postprocess.append(result[0])
to_postprocess["tokens"].append(result[0]["tokens"])
to_postprocess["masks"].append(result[0]["masks"])
to_postprocess["scores"].append(result[0]["scores"])
if result[0].get("advantages", None) is not None:
to_postprocess["advantages"].append(result[0]["advantages"])
if result[0].get("ref_logprobs", None) is not None:
to_postprocess["ref_logprobs"].append(result[0]["ref_logprobs"])
if result[0].get("messages", None) is not None:
to_postprocess["messages"].append(result[0]["messages"])
if result[0].get("group_overrides", None) is not None:
to_postprocess["group_overrides"].update(result[0]["group_overrides"])
if result[0].get("overrides", None) is not None:
to_postprocess["overrides"].append(result[0]["overrides"])
backlog.extend(result[1])
random.shuffle(backlog)
return to_postprocess, backlog
async def postprocess_histories(

View file

@ -4,7 +4,17 @@ import logging
from typing import Any, List, Optional, Union
import scipy
import torch
try:
import torch
except ImportError as e:
logger = logging.getLogger(__name__)
logger.warning(
"torch not installed, please install atroposlib[rewardfns] to use this reward function"
)
raise e
from transformers import AutoModel, AutoTokenizer
from .registry import registry

View file

@ -1,7 +1,7 @@
import math
import numpy as np
import pytest
import torch
# Adjust the import below if your functions are in a different module.
from atroposlib.utils.advantages import (
@ -23,9 +23,9 @@ def test_allclose_to_first_vector():
"""Test that return_vector=True returns a tensor of booleans."""
values = [1.0, 1.000000001, 1.000000002]
result = allclose_to_first(values, return_vector=True)
assert isinstance(result, torch.Tensor)
assert isinstance(result, np.ndarray)
# All comparisons should be True.
assert torch.all(result)
assert np.all(result)
def test_allclose_to_first_not_close():
@ -74,15 +74,15 @@ def test_compute_stats_jagged():
def test_compute_discounted_returns():
    """Test compute_discounted_returns with a numpy array input."""
    # Post-merge this module is numpy-based; the stale torch duplicate
    # lines from the merge have been removed.
    rewards = np.array([1.0, 1.0, 1.0])
    gamma = 0.9
    returns = compute_discounted_returns(rewards, gamma)
    # For a 3-element vector:
    # t=2: 1.0
    # t=1: 1.0 + 0.9*1.0 = 1.9
    # t=0: 1.0 + 0.9*1.9 = 2.71
    expected = np.array([2.71, 1.9, 1.0])
    assert np.allclose(returns, expected, rtol=1e-5, atol=1e-8)
def test_compute_discounted_returns_list_input():
    """Test compute_discounted_returns with a plain Python list input."""
    rewards = [1, 1, 1]
    gamma = 0.0  # With gamma=0, the returns should equal the rewards.
    returns = compute_discounted_returns(rewards, gamma)
    expected = np.array([1.0, 1.0, 1.0])
    assert np.allclose(returns, expected, rtol=1e-5, atol=1e-8)
def test_compute_grpo_process_supervision_advantages_cumsum():

View file

@ -1,31 +1,32 @@
from typing import Sequence
import torch
import numpy as np
from atroposlib.type_definitions import number
TensorLike = torch.Tensor | Sequence[torch.Tensor] | Sequence[Sequence]
NumpyArrayLike = np.ndarray | Sequence[np.ndarray] | Sequence[Sequence]
# Type alias for vector of bools
BoolVector = torch.Tensor
BoolVector = np.ndarray
def allclose_to_first(
    values: "NumpyArrayLike",
    rtol: float = 1e-05,
    atol: float = 1e-08,
    equal_nan: bool = False,
    return_vector: bool = False,
) -> "BoolVector | bool":
    """
    Check if all arrays in `values` are close to the first array `values[0]`
    using a vectorized approach.

    If `return_vector` is False (default), returns a single boolean indicating
    whether every array is close to the first array. If `return_vector` is
    True, returns a vector of booleans where each element corresponds to
    whether the respective array in `values` is close to the first array.
    The first element is always True.

    Args:
        values (np.ndarray | Sequence[np.ndarray] | Sequence[Sequence]):
            Nested list of values to compare. Must be rectangular, but not
            necessarily 2D.
        rtol (float, optional): Relative tolerance. Defaults to 1e-05.
        atol (float, optional): Absolute tolerance. Defaults to 1e-08.
        equal_nan (bool, optional): Whether to treat NaNs in the same
            position as equal. Defaults to False.
        return_vector (bool, optional): Whether to return the per-row
            boolean vector instead of a single bool. Defaults to False.

    Returns:
        bool or BoolVector:
            - If `return_vector` is False, returns True if all arrays are
              close to the first array; otherwise, returns False.
            - If `return_vector` is True, returns a 1D array of bools where
              the first element is True (as the reference array is trivially
              close to itself), and each subsequent element indicates whether
              the corresponding array is close to the first array.
    """
    # Removed merge residue: the stale torch-based duplicate lines and the
    # commented-out `# values: TensorLike,` parameter left by the merge.
    if not isinstance(values, np.ndarray):
        values = np.array(values)
    reference = values[0]
    # Broadcasting compares every row against the reference row at once.
    is_close = np.isclose(values, reference, rtol=rtol, atol=atol, equal_nan=equal_nan)
    # Flatten dimensions after the first so each row reduces to one bool.
    result_vector = np.all(is_close.reshape(is_close.shape[0], -1), axis=1)
    return result_vector if return_vector else bool(np.all(result_vector))
def compute_stats(data: Sequence[number | Sequence]) -> dict[str, float]:
@ -104,23 +103,23 @@ def compute_stats(data: Sequence[number | Sequence]) -> dict[str, float]:
return {"mean": mean, "var": variance}
def compute_discounted_returns(rewards: torch.Tensor, gamma: float) -> torch.Tensor:
def compute_discounted_returns(rewards: np.ndarray, gamma: float) -> np.ndarray:
"""Compute discounted returns from a 1D vector of rewards.
Given a list or torch tensor of rewards and a discount factor, this function computes
Given a list or numpy array of rewards and a discount factor, this function computes
the discounted return at each timestep. The discounted return at time t is defined as:
G_t = rewards[t] + gamma * rewards[t+1] + gamma^2 * rewards[t+2] + ...
Args:
rewards (list[float] or torch.Tensor): A 1D list or tensor of rewards.
rewards (list[float] or np.ndarray): A 1D list or array of rewards.
gamma (float): The discount factor (should be between 0 and 1).
Returns:
list[float]: A list containing the discounted returns for each timestep.
"""
if not isinstance(rewards, torch.Tensor):
rewards = torch.tensor(rewards, dtype=torch.float)
discounted_returns = torch.empty_like(rewards)
if not isinstance(rewards, np.ndarray):
rewards = np.array(rewards, dtype=np.float32) # Use float32 for numpy default
discounted_returns = np.empty_like(rewards)
running_return = 0.0
for t in reversed(range(len(rewards))):
@ -132,7 +131,7 @@ def compute_discounted_returns(rewards: torch.Tensor, gamma: float) -> torch.Ten
def compute_grpo_process_supervision_advantages(
rewards: Sequence[Sequence[number]], gamma: float = None, std_tol: float = 1e-8
) -> list[torch.Tensor]:
) -> list[np.ndarray]:
"""
Given a (possibly jagged) list of list of rewards, compute advantages for GRPO.
@ -144,7 +143,7 @@ def compute_grpo_process_supervision_advantages(
std_tol (float): The tolerance for the standard deviation.
Returns:
A list of tensors of advantages.
A list of arrays of advantages.
Raises:
ValueError: If the standard deviation of the flattened rewards is smaller than the tolerance.
@ -155,13 +154,11 @@ def compute_grpo_process_supervision_advantages(
if std < std_tol:
raise ValueError(f"`std` is smaller than tolerance of {std_tol}.")
normalized_rewards = [
(torch.tensor(trajectory) - mean) / std for trajectory in rewards
]
normalized_rewards = [(np.array(trajectory) - mean) / std for trajectory in rewards]
if gamma is None:
advantages = [
trajectory.flip(dims=[0]).cumsum(dim=0).flip(dims=[0])
np.flip(np.cumsum(np.flip(trajectory, axis=0), axis=0), axis=0)
for trajectory in normalized_rewards
]
else:

View file

@ -1,4 +1,4 @@
import torch
import numpy as np
from transformers import PreTrainedTokenizer
from atroposlib.type_definitions import Message
@ -39,7 +39,7 @@ def tokenize_for_trainer(
# (e.g. current date). e.g. consider a system prompt that depends on the current date and a run that crosses
# midnight from 3/9 to 3/10 under a tokenizer that tokenizes 3/9 and 3/10 with a different number of tokens.
masks = torch.ones(len(tokens), dtype=torch.long) * -100
masks = np.ones(len(tokens), dtype=np.int64) * -100
for i, msg in enumerate(chat):
if msg["role"] in UNMASKED_ROLES:
@ -51,7 +51,7 @@ def tokenize_for_trainer(
)
start_idx = len(prefix_tokens)
end_idx = len(unmasked_tokens)
masks[start_idx:end_idx] = torch.tensor(unmasked_tokens[start_idx:])
masks[start_idx:end_idx] = np.array(unmasked_tokens[start_idx:])
masks = masks.tolist()
if finish_reason == "length":