mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-26 17:13:09 +00:00
Add additional completions table info: metric, magnitude, and direction for ground truth
This commit is contained in:
parent
c3b80832e9
commit
d2dbab7d22
1 changed files with 60 additions and 0 deletions
|
|
@ -4,11 +4,13 @@ from typing import List, Optional, Tuple, Union, Dict
|
||||||
|
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from tqdm.asyncio import tqdm_asyncio
|
from tqdm.asyncio import tqdm_asyncio
|
||||||
|
import wandb
|
||||||
|
|
||||||
from atroposlib.envs.base import (
|
from atroposlib.envs.base import (
|
||||||
BaseEnv,
|
BaseEnv,
|
||||||
BaseEnvConfig,
|
BaseEnvConfig,
|
||||||
EvalHandlingEnum,
|
EvalHandlingEnum,
|
||||||
|
Item,
|
||||||
OpenaiConfig,
|
OpenaiConfig,
|
||||||
ScoredDataGroup,
|
ScoredDataGroup,
|
||||||
)
|
)
|
||||||
|
|
@ -536,6 +538,64 @@ class FundamentalPredictionEnv(BaseEnv):
|
||||||
|
|
||||||
await super().wandb_log(wandb_metrics)
|
await super().wandb_log(wandb_metrics)
|
||||||
|
|
||||||
|
async def add_rollouts_for_wandb(
|
||||||
|
self,
|
||||||
|
scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
|
||||||
|
item: Item = None,
|
||||||
|
):
|
||||||
|
# Initialize rollouts_for_wandb if not exists
|
||||||
|
if not hasattr(self, "rollouts_for_wandb"):
|
||||||
|
self.rollouts_for_wandb = []
|
||||||
|
|
||||||
|
# Get number of examples to keep
|
||||||
|
num_keep = getattr(self.config, "num_rollouts_per_group_for_logging", -1)
|
||||||
|
|
||||||
|
if num_keep == -1:
|
||||||
|
num_keep = self.config.group_size
|
||||||
|
|
||||||
|
# Get fundamental metric from item
|
||||||
|
fundamental_metric = item[3]
|
||||||
|
|
||||||
|
# Add examples to rollouts
|
||||||
|
self.rollouts_for_wandb.append(
|
||||||
|
[
|
||||||
|
(
|
||||||
|
self.tokenizer.decode(scored_data["tokens"][i]),
|
||||||
|
scored_data["scores"][i],
|
||||||
|
item[1], # expected direction (maintained/raised/reduced)
|
||||||
|
item[2], # expected magnitude
|
||||||
|
fundamental_metric, # metric type being predicted
|
||||||
|
)
|
||||||
|
for i in range(min(num_keep, len(scored_data["tokens"])))
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Keep buffer size limited
|
||||||
|
max_rollouts = getattr(self.config, "num_rollouts_to_keep", 10)
|
||||||
|
if len(self.rollouts_for_wandb) > max_rollouts:
|
||||||
|
self.rollouts_for_wandb.pop(0)
|
||||||
|
|
||||||
|
async def create_rollout_table(self, wandb_metrics):
|
||||||
|
if hasattr(self, "rollouts_for_wandb") and len(self.rollouts_for_wandb) > 0:
|
||||||
|
table = wandb.Table(columns=[
|
||||||
|
"text",
|
||||||
|
"score",
|
||||||
|
"expected_direction",
|
||||||
|
"expected_magnitude",
|
||||||
|
"fundamental_metric"
|
||||||
|
])
|
||||||
|
|
||||||
|
for group in self.rollouts_for_wandb:
|
||||||
|
for item in group:
|
||||||
|
table.add_data(item[0], item[1], item[2], item[3], item[4])
|
||||||
|
|
||||||
|
wandb_metrics["train/rollouts"] = table
|
||||||
|
|
||||||
|
# Clear rollouts after logging
|
||||||
|
self.rollouts_for_wandb = []
|
||||||
|
|
||||||
|
return wandb_metrics
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
FundamentalPredictionEnv.cli()
|
FundamentalPredictionEnv.cli()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue