diff --git a/environments/fundamental_prediction_environment.py b/environments/fundamental_prediction_environment.py
index a38c561c..4b283b03 100644
--- a/environments/fundamental_prediction_environment.py
+++ b/environments/fundamental_prediction_environment.py
@@ -4,11 +4,13 @@ from typing import List, Optional, Tuple, Union, Dict
 
 from datasets import load_dataset
 from tqdm.asyncio import tqdm_asyncio
+import wandb
 
 from atroposlib.envs.base import (
     BaseEnv,
     BaseEnvConfig,
     EvalHandlingEnum,
+    Item,
     OpenaiConfig,
     ScoredDataGroup,
 )
@@ -536,6 +538,83 @@ class FundamentalPredictionEnv(BaseEnv):
 
         await super().wandb_log(wandb_metrics)
 
+    async def add_rollouts_for_wandb(
+        self,
+        scored_data: Union[ScoredDataGroup, List[ScoredDataGroup]],
+        item: Optional[Item] = None,
+    ):
+        """Buffer decoded rollouts from one scored group for wandb logging.
+
+        Each buffered row is (decoded text, score, item[1], item[2], item[3]).
+        When ``item`` is None the three item-derived fields are stored as
+        None instead of raising on the subscript.
+        """
+        # Initialize rollouts_for_wandb if not exists
+        if not hasattr(self, "rollouts_for_wandb"):
+            self.rollouts_for_wandb = []
+
+        # Rows to keep per group; -1 falls back to config.group_size
+        num_keep = getattr(self.config, "num_rollouts_per_group_for_logging", -1)
+        if num_keep == -1:
+            num_keep = self.config.group_size
+
+        # ``item`` defaults to None, so guard before indexing into it
+        if item is None:
+            direction = magnitude = fundamental_metric = None
+        else:
+            direction = item[1]  # expected direction (maintained/raised/reduced)
+            magnitude = item[2]  # expected magnitude
+            fundamental_metric = item[3]  # metric type being predicted
+
+        tokens = scored_data["tokens"]
+        scores = scored_data["scores"]
+        self.rollouts_for_wandb.append(
+            [
+                (
+                    self.tokenizer.decode(tokens[i]),
+                    scores[i],
+                    direction,
+                    magnitude,
+                    fundamental_metric,
+                )
+                for i in range(min(num_keep, len(tokens)))
+            ]
+        )
+
+        # Keep buffer size limited by dropping the oldest group
+        max_rollouts = getattr(self.config, "num_rollouts_to_keep", 10)
+        if len(self.rollouts_for_wandb) > max_rollouts:
+            self.rollouts_for_wandb.pop(0)
+
+    async def create_rollout_table(self, wandb_metrics):
+        """Flush buffered rollouts into a wandb table under "train/rollouts".
+
+        Returns ``wandb_metrics``; it is unchanged when nothing is buffered.
+        The buffer is reset to an empty list in either case.
+        """
+        rollouts = getattr(self, "rollouts_for_wandb", None)
+        if rollouts:
+            table = wandb.Table(
+                columns=[
+                    "text",
+                    "score",
+                    "expected_direction",
+                    "expected_magnitude",
+                    "fundamental_metric",
+                ]
+            )
+
+            for group in rollouts:
+                for row in group:
+                    table.add_data(*row)
+
+            wandb_metrics["train/rollouts"] = table
+
+        # Clear rollouts after logging
+        self.rollouts_for_wandb = []
+
+        return wandb_metrics
+
 
 if __name__ == "__main__":
     FundamentalPredictionEnv.cli()