mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-22 16:48:57 +00:00
Add additional completions table info: metric, magnitude, and direction for ground truth
This commit is contained in:
parent
c3b80832e9
commit
d2dbab7d22
1 changed file with 60 additions and 0 deletions
|
|
@ -4,11 +4,13 @@ from typing import List, Optional, Tuple, Union, Dict
|
|||
|
||||
from datasets import load_dataset
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
import wandb
|
||||
|
||||
from atroposlib.envs.base import (
|
||||
BaseEnv,
|
||||
BaseEnvConfig,
|
||||
EvalHandlingEnum,
|
||||
Item,
|
||||
OpenaiConfig,
|
||||
ScoredDataGroup,
|
||||
)
|
||||
|
|
@ -536,6 +538,64 @@ class FundamentalPredictionEnv(BaseEnv):
|
|||
|
||||
await super().wandb_log(wandb_metrics)
|
||||
|
||||
async def add_rollouts_for_wandb(
    self,
    scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
    item: Item = None,
):
    """Buffer one group of decoded rollouts, with ground truth, for wandb.

    Each buffered row is a 5-tuple
    ``(text, score, expected_direction, expected_magnitude,
    fundamental_metric)``; one list of rows is appended to
    ``self.rollouts_for_wandb`` per call.

    Args:
        scored_data: Indexed here with the ``"tokens"`` and ``"scores"``
            keys. NOTE(review): the annotation also allows a list of
            groups, which this method does not handle -- confirm callers
            only pass a single group.
        item: Source item whose positions 1, 2 and 3 supply the expected
            direction, expected magnitude and metric type. When omitted
            (``None``), those three columns are filled with ``None``
            instead of raising ``TypeError``.
    """
    # Create the buffer lazily in case nothing has initialised it yet.
    if not hasattr(self, "rollouts_for_wandb"):
        self.rollouts_for_wandb = []

    # -1 means "log every rollout in the group".
    num_keep = getattr(self.config, "num_rollouts_per_group_for_logging", -1)
    if num_keep == -1:
        num_keep = self.config.group_size

    # Ground-truth fields are identical for every row in the group, so
    # resolve them once. ``item`` defaults to None and must not be indexed
    # in that case.
    if item is None:
        expected_direction = expected_magnitude = fundamental_metric = None
    else:
        expected_direction = item[1]  # maintained/raised/reduced
        expected_magnitude = item[2]
        fundamental_metric = item[3]  # metric type being predicted

    tokens = scored_data["tokens"]
    scores = scored_data["scores"]
    decode = self.tokenizer.decode

    # Never index past the rollouts that are actually present.
    row_count = min(num_keep, len(tokens))
    self.rollouts_for_wandb.append(
        [
            (
                decode(tokens[i]),
                scores[i],
                expected_direction,
                expected_magnitude,
                fundamental_metric,
            )
            for i in range(row_count)
        ]
    )

    # Keep the buffer bounded: drop the oldest groups beyond the limit.
    max_rollouts = getattr(self.config, "num_rollouts_to_keep", 10)
    excess = len(self.rollouts_for_wandb) - max_rollouts
    if excess > 0:
        del self.rollouts_for_wandb[:excess]
|
||||
|
||||
async def create_rollout_table(self, wandb_metrics):
    """Publish buffered rollouts as a wandb table and return the metrics.

    When ``self.rollouts_for_wandb`` exists and is non-empty, every buffered
    row is added to a five-column ``wandb.Table`` that is stored under
    ``wandb_metrics["train/rollouts"]``. The same ``wandb_metrics`` object
    is returned.
    """
    has_rollouts = (
        hasattr(self, "rollouts_for_wandb") and len(self.rollouts_for_wandb) > 0
    )

    if has_rollouts:
        column_names = [
            "text",
            "score",
            "expected_direction",
            "expected_magnitude",
            "fundamental_metric",
        ]
        rollout_table = wandb.Table(columns=column_names)

        # One table row per buffered rollout, across all buffered groups.
        for rollout_group in self.rollouts_for_wandb:
            for row in rollout_group:
                rollout_table.add_data(row[0], row[1], row[2], row[3], row[4])

        wandb_metrics["train/rollouts"] = rollout_table

    # Clear rollouts after logging.
    # NOTE(review): indentation was lost in the reviewed copy; this reset is
    # assumed to run unconditionally -- confirm against the real file.
    self.rollouts_for_wandb = []

    return wandb_metrics
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Executed only when the file is run as a script: hand control to the
    # environment class's ``cli()`` entry point.
    FundamentalPredictionEnv.cli()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue