Add additional completions table info: metric, magnitude, and direction for ground truth

This commit is contained in:
teknium1 2025-05-04 03:30:50 -07:00
parent c3b80832e9
commit d2dbab7d22

View file

@ -4,11 +4,13 @@ from typing import List, Optional, Tuple, Union, Dict
from datasets import load_dataset
from tqdm.asyncio import tqdm_asyncio
import wandb
from atroposlib.envs.base import (
BaseEnv,
BaseEnvConfig,
EvalHandlingEnum,
Item,
OpenaiConfig,
ScoredDataGroup,
)
@ -536,6 +538,64 @@ class FundamentalPredictionEnv(BaseEnv):
await super().wandb_log(wandb_metrics)
async def add_rollouts_for_wandb(
    self,
    scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
    item: Item = None,
):
    """Buffer one group of decoded rollouts, plus ground truth, for wandb.

    Each buffered row is a 5-tuple:
    ``(decoded_text, score, expected_direction, expected_magnitude,
    fundamental_metric)``.

    Args:
        scored_data: Scored group; indexed here by the ``"tokens"`` and
            ``"scores"`` keys, which are read position by position.
        item: Source item the group was generated from. ``item[1]`` is the
            expected direction (maintained/raised/reduced), ``item[2]`` the
            expected magnitude and ``item[3]`` the metric being predicted.
            When omitted (the declared default), the three ground-truth
            fields are recorded as ``None`` instead of raising.
    """
    # Lazily create the buffer so the method works even if the attribute
    # was never initialised on this instance.
    if not hasattr(self, "rollouts_for_wandb"):
        self.rollouts_for_wandb = []

    # -1 (or a missing config field) means "log the whole group".
    num_keep = getattr(self.config, "num_rollouts_per_group_for_logging", -1)
    if num_keep == -1:
        num_keep = self.config.group_size

    # `item` defaults to None in the signature, so it must not be
    # subscripted unconditionally.
    if item is None:
        expected_direction = None
        expected_magnitude = None
        fundamental_metric = None
    else:
        expected_direction = item[1]
        expected_magnitude = item[2]
        fundamental_metric = item[3]

    token_groups = scored_data["tokens"]
    scores = scored_data["scores"]
    rows_to_log = min(num_keep, len(token_groups))

    self.rollouts_for_wandb.append(
        [
            (
                self.tokenizer.decode(token_groups[i]),
                scores[i],
                expected_direction,
                expected_magnitude,
                fundamental_metric,
            )
            for i in range(rows_to_log)
        ]
    )

    # Keep buffer size limited: drop the oldest group once over the cap.
    max_rollouts = getattr(self.config, "num_rollouts_to_keep", 10)
    if len(self.rollouts_for_wandb) > max_rollouts:
        self.rollouts_for_wandb.pop(0)
async def create_rollout_table(self, wandb_metrics):
    """Flush buffered rollouts into a wandb table and reset the buffer.

    When at least one rollout group is buffered, a ``wandb.Table`` with one
    row per buffered rollout is stored under ``"train/rollouts"`` in
    ``wandb_metrics``. The buffer is then emptied and the same
    ``wandb_metrics`` mapping is returned.
    """
    # A missing buffer is treated the same as an empty one.
    buffered_groups = getattr(self, "rollouts_for_wandb", ())

    if len(buffered_groups) > 0:
        column_names = [
            "text",
            "score",
            "expected_direction",
            "expected_magnitude",
            "fundamental_metric",
        ]
        rollout_table = wandb.Table(columns=column_names)
        for rollout_group in buffered_groups:
            for row in rollout_group:
                rollout_table.add_data(row[0], row[1], row[2], row[3], row[4])
        wandb_metrics["train/rollouts"] = rollout_table

    # Clear rollouts after logging
    self.rollouts_for_wandb = []
    return wandb_metrics
# Invoke the environment class's `cli()` entry point only when this file is
# executed as a script, not when it is imported as a module.
if __name__ == "__main__":
    FundamentalPredictionEnv.cli()