Merge branch 'main' into blackjack2-env

This commit is contained in:
Shannon Sands 2025-05-13 07:54:04 +10:00
commit 36f6822d71
9 changed files with 455 additions and 61 deletions

View file

@ -1,31 +1,32 @@
from typing import Sequence
import torch
import numpy as np
from atroposlib.type_definitions import number
TensorLike = torch.Tensor | Sequence[torch.Tensor] | Sequence[Sequence]
NumpyArrayLike = np.ndarray | Sequence[np.ndarray] | Sequence[Sequence]
# Type alias for vector of bools
BoolVector = torch.Tensor
BoolVector = np.ndarray
def allclose_to_first(
values: TensorLike,
# values: TensorLike,
values: NumpyArrayLike,
rtol: float = 1e-05,
atol: float = 1e-08,
equal_nan: bool = False,
return_vector: bool = False,
) -> BoolVector | bool:
"""
Check if all tensors in `values` are close to the first tensor `values[0]` using a vectorized approach.
Check if all arrays in `values` are close to the first array `values[0]` using a vectorized approach.
If `return_vector` is False (default), returns a single boolean indicating whether
every tensor is close to the first tensor. If `return_vector` is True, returns a list
of booleans where each element corresponds to whether the respective tensor in
`values` is close to the first tensor. The first element is always True.
every array is close to the first array. If `return_vector` is True, returns a 1D array
of booleans where each element corresponds to whether the respective array in
`values` is close to the first array. The first element is always True.
Args:
values (torch.Tensor | Sequence[torch.Tensor] | Sequence[Sequence]):
values (np.ndarray | Sequence[np.ndarray] | Sequence[Sequence]):
Nested list of values to compare. Must be rectangular, but not necessarily 2D.
rtol (float, optional): Relative tolerance. Defaults to 1e-05.
atol (float, optional): Absolute tolerance. Defaults to 1e-08.
@ -35,24 +36,22 @@ def allclose_to_first(
Returns:
bool or BoolVector:
- If `return_vector` is False, returns True if all tensors are close to the first tensor;
- If `return_vector` is False, returns True if all arrays are close to the first array;
otherwise, returns False.
- If `return_vector` is True, returns a 1D tensor of bools where the first element is True
(as the reference tensor is trivially close to itself), and each subsequent element indicates
whether the corresponding tensor is close to the first tensor.
- If `return_vector` is True, returns a 1D array of bools where the first element is True
(as the reference array is trivially close to itself), and each subsequent element indicates
whether the corresponding array is close to the first array.
"""
if not isinstance(values, torch.Tensor):
values = torch.tensor(values)
if not isinstance(values, np.ndarray):
values = np.array(values)
reference = values[0]
is_close = torch.isclose(
values, reference, rtol=rtol, atol=atol, equal_nan=equal_nan
)
is_close = np.isclose(values, reference, rtol=rtol, atol=atol, equal_nan=equal_nan)
# flatten dimensions after first
result_vector = torch.all(is_close.view(is_close.size(0), -1), dim=1)
result_vector = np.all(is_close.reshape(is_close.shape[0], -1), axis=1)
return result_vector if return_vector else bool(torch.all(result_vector))
return result_vector if return_vector else bool(np.all(result_vector))
def compute_stats(data: Sequence[number | Sequence]) -> dict[str, float]:
@ -104,23 +103,23 @@ def compute_stats(data: Sequence[number | Sequence]) -> dict[str, float]:
return {"mean": mean, "var": variance}
def compute_discounted_returns(rewards: torch.Tensor, gamma: float) -> torch.Tensor:
def compute_discounted_returns(rewards: np.ndarray, gamma: float) -> np.ndarray:
"""Compute discounted returns from a 1D vector of rewards.
Given a list or torch tensor of rewards and a discount factor, this function computes
Given a list or numpy array of rewards and a discount factor, this function computes
the discounted return at each timestep. The discounted return at time t is defined as:
G_t = rewards[t] + gamma * rewards[t+1] + gamma^2 * rewards[t+2] + ...
Args:
rewards (list[float] or torch.Tensor): A 1D list or tensor of rewards.
rewards (list[float] or np.ndarray): A 1D list or array of rewards.
gamma (float): The discount factor (should be between 0 and 1).
Returns:
np.ndarray: A 1D array containing the discounted returns for each timestep.
"""
if not isinstance(rewards, torch.Tensor):
rewards = torch.tensor(rewards, dtype=torch.float)
discounted_returns = torch.empty_like(rewards)
if not isinstance(rewards, np.ndarray):
rewards = np.array(rewards, dtype=np.float32) # float32 mirrors the previous torch.float dtype (NumPy's own default is float64)
discounted_returns = np.empty_like(rewards)
running_return = 0.0
for t in reversed(range(len(rewards))):
@ -132,7 +131,7 @@ def compute_discounted_returns(rewards: torch.Tensor, gamma: float) -> torch.Ten
def compute_grpo_process_supervision_advantages(
rewards: Sequence[Sequence[number]], gamma: float = None, std_tol: float = 1e-8
) -> list[torch.Tensor]:
) -> list[np.ndarray]:
"""
Given a (possibly jagged) list of list of rewards, compute advantages for GRPO.
@ -144,7 +143,7 @@ def compute_grpo_process_supervision_advantages(
std_tol (float): The tolerance for the standard deviation.
Returns:
A list of tensors of advantages.
A list of arrays of advantages.
Raises:
ValueError: If the standard deviation of the flattened rewards is smaller than the tolerance.
@ -155,13 +154,11 @@ def compute_grpo_process_supervision_advantages(
if std < std_tol:
raise ValueError(f"`std` is smaller than tolerance of {std_tol}.")
normalized_rewards = [
(torch.tensor(trajectory) - mean) / std for trajectory in rewards
]
normalized_rewards = [(np.array(trajectory) - mean) / std for trajectory in rewards]
if gamma is None:
advantages = [
trajectory.flip(dims=[0]).cumsum(dim=0).flip(dims=[0])
np.flip(np.cumsum(np.flip(trajectory, axis=0), axis=0), axis=0)
for trajectory in normalized_rewards
]
else:

View file

@ -1,4 +1,4 @@
import torch
import numpy as np
from transformers import PreTrainedTokenizer
from atroposlib.type_definitions import Message
@ -39,7 +39,7 @@ def tokenize_for_trainer(
# (e.g. current date). e.g. consider a system prompt that depends on the current date and a run that crosses
# midnight from 3/9 to 3/10 under a tokenizer that tokenizes 3/9 and 3/10 with a different number of tokens.
masks = torch.ones(len(tokens), dtype=torch.long) * -100
masks = np.ones(len(tokens), dtype=np.int64) * -100
for i, msg in enumerate(chat):
if msg["role"] in UNMASKED_ROLES:
@ -51,7 +51,7 @@ def tokenize_for_trainer(
)
start_idx = len(prefix_tokens)
end_idx = len(unmasked_tokens)
masks[start_idx:end_idx] = torch.tensor(unmasked_tokens[start_idx:])
masks[start_idx:end_idx] = np.array(unmasked_tokens[start_idx:])
masks = masks.tolist()
if finish_reason == "length":