Merge branch 'main' into blackjack2-env

This commit is contained in:
Shannon Sands 2025-05-13 07:54:04 +10:00
commit 36f6822d71
9 changed files with 455 additions and 61 deletions

View file

@ -27,9 +27,9 @@ These methods **must** be implemented in your subclass:
* **`async def collect_trajectories(self, item: Item) -> Tuple[Union[Optional[ScoredDataGroup], List[Optional[ScoredDataGroup]], List[Any | None]], List[Item]]`**: The default implementation of this method runs `collect_trajectory` (see below) multiple times in parallel (controlled by `group_size`). You can override this if you have a more efficient way to generate the entire group of responses/trajectories at once based on the input `item` (e.g. the `n` parameter in the OpenAI chat completions API) or some desired coupling of rollouts (e.g. via MCTS). It should return the collected group data and a list of backlog items.
* **`async def collect_trajectory(self, item: Item) -> Tuple[Any | None, List[Item]]`**: If the rollouts for your environment can be sampled independently, the easiest way to implement GRPO-style grouping is to define the `collect_trajectory` method and use the default implementation of `collect_trajectories` which runs `group_size` instances of `collect_trajectory` in parallel. This method defines the logic for a *single* logical trajectory collection step based on the input `item`.
* **`async def collect_trajectory(self, item: Item) -> Tuple[Any | ScoredDataItem | None, List[Item]]`**: If the rollouts for your environment can be sampled independently, the easiest way to implement GRPO-style grouping is to define the `collect_trajectory` method and use the default implementation of `collect_trajectories` which runs `group_size` instances of `collect_trajectory` in parallel. This method defines the logic for a *single* logical trajectory collection step based on the input `item`.
* **Return value**: It returns a tuple containing:\
1. The collected data for this step (one trajectory). This data can be processed further in `postprocess_histories`, if you require additional filtering right before sending to the API.\
1. The ScoredDataItem for this step (one trajectory). This data can be processed further in `postprocess_histories`, if you require additional filtering right before sending to the API.
2. A list of new `Item` objects to be added to the backlog for future processing (e.g., follow-up prompts).\
* **Should I define `collect_trajectory` or override `collect_trajectories`?** If you've got some way to generate your group more efficiently than a bunch of separate but parallel calls to `collect_trajectory`, or if your rollouts aren't independent as in MCTS, you should override `collect_trajectories`. If simplicity and iteration speed is more valuable than efficiency (e.g. at the start of a development cycle) and your rollouts are independent then `collect_trajectory` is for you.

View file

@ -61,6 +61,17 @@ class ScoredDataGroup(TypedDict):
overrides: Optional[List[Dict]]
class ScoredDataItem(TypedDict):
tokens: List[int]
masks: List[int]
scores: float
advantages: Optional[List[float]]
ref_logprobs: Optional[List[float]]
messages: Optional[List[Message]]
group_overrides: Optional[Dict]
overrides: Optional[Dict]
class EvalHandlingEnum(Enum):
"""
Enum for handling evals.
@ -229,7 +240,9 @@ class BaseEnv(ABC):
"""
return cls.env_config_cls(), ServerBaseline()
async def collect_trajectory(self, item: Item) -> Tuple[Any | None, List[Item]]:
async def collect_trajectory(
self, item: Item
) -> Tuple[Optional[Union[ScoredDataItem, Any]], List[Item]]:
raise NotImplementedError(
"Handle env single method must be implemented in subclass "
)
@ -249,13 +262,38 @@ class BaseEnv(ABC):
for _ in range(self.config.group_size):
tasks.append(self.collect_trajectory(item))
results = await asyncio.gather(*tasks)
if any(not isinstance(result[0], dict) for result in results):
logging.error("something wasn't a ScoredDataItem")
raise ValueError(
"collect_trajectory must return a ScoredDataItem or None to use the default "
"collect_trajectories method"
)
backlog = []
to_postprocess = []
to_postprocess = ScoredDataGroup()
to_postprocess["tokens"] = []
to_postprocess["masks"] = []
to_postprocess["scores"] = []
to_postprocess["advantages"] = []
to_postprocess["ref_logprobs"] = []
to_postprocess["messages"] = []
to_postprocess["group_overrides"] = {}
to_postprocess["overrides"] = []
print("Processing results")
for result in results:
if result[0] is not None:
to_postprocess.append(result[0])
to_postprocess["tokens"].append(result[0]["tokens"])
to_postprocess["masks"].append(result[0]["masks"])
to_postprocess["scores"].append(result[0]["scores"])
if result[0].get("advantages", None) is not None:
to_postprocess["advantages"].append(result[0]["advantages"])
if result[0].get("ref_logprobs", None) is not None:
to_postprocess["ref_logprobs"].append(result[0]["ref_logprobs"])
if result[0].get("messages", None) is not None:
to_postprocess["messages"].append(result[0]["messages"])
if result[0].get("group_overrides", None) is not None:
to_postprocess["group_overrides"].update(result[0]["group_overrides"])
if result[0].get("overrides", None) is not None:
to_postprocess["overrides"].append(result[0]["overrides"])
backlog.extend(result[1])
random.shuffle(backlog)
return to_postprocess, backlog
async def postprocess_histories(

View file

@ -4,7 +4,17 @@ import logging
from typing import Any, List, Optional, Union
import scipy
import torch
try:
import torch
except ImportError as e:
logger = logging.getLogger(__name__)
logger.warning(
"torch not installed, please install atroposlib[rewardfns] to use this reward function"
)
raise e
from transformers import AutoModel, AutoTokenizer
from .registry import registry