import asyncio
import logging
from typing import Any, Dict, List, Optional, Tuple, Union

from datasets import load_dataset
from pydantic import Field

from atroposlib.envs.base import BaseEnv, BaseEnvConfig, ScoredDataGroup
from atroposlib.envs.reward_fns import registry
from atroposlib.envs.reward_fns.combined_reward import CombinedReward
from atroposlib.type_definitions import Item
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class DatasetEnvConfig(BaseEnvConfig):
    dataset_name: str = Field(..., description="HuggingFace dataset name")
    dataset_config: Optional[str] = Field(
        None, description="Dataset configuration name"
    )
    split: str = Field("train", description="Dataset split to use")
    dataset_path: Optional[str] = Field(
        None, description="Local path to dataset (alternative to dataset_name)"
    )
    prompt_field: str = Field(..., description="Field in dataset to use as prompt")
    answer_field: Optional[str] = Field(
        None, description="Field in dataset to use as answer"
    )
    ground_truth_field: Optional[str] = Field(
        None, description="Field in dataset containing canonical correct answer"
    )
    system_prompt: Optional[str] = Field(None, description="System prompt to use")
    prefill: Optional[str] = Field(
        None, description="Text to prefill the completion with"
    )
    shuffle_dataset: bool = Field(True, description="Whether to shuffle the dataset")
    max_generations_per_prompt: int = Field(
        1, description="Number of generations per prompt for collection"
    )
    include_messages_in_scoring: bool = Field(
        False, description="Whether to include messages in scoring"
    )
    reward_funcs: List[str] = Field(
        default_factory=list,
        description="List of reward function names to apply (legacy)",
    )
    reward_functions: List[Union[str, Dict[str, Any]]] = Field(
        default_factory=list,
        description="List of reward functions to apply (string names or full configs)",
    )

    # Completion parameters
    temperature: float = Field(0.7, description="Temperature for generation")
    top_p: float = Field(0.9, description="Top-p for generation")
    max_tokens: int = Field(4096, description="Maximum tokens for generation")
    length_warmup_steps: int = Field(0, description="Steps for length warmup")
    min_tokens: int = Field(0, description="Minimum tokens for generation")

    eval_dataset_name: Optional[str] = Field(
        None, description="Evaluation dataset name"
    )
    eval_dataset_config: Optional[str] = Field(
        None, description="Evaluation dataset config"
    )
    eval_split: Optional[str] = Field(None, description="Evaluation dataset split")


class DatasetEnv(BaseEnv):
    def __init__(
        self, config: DatasetEnvConfig, server_configs, slurm=True, testing=False
    ):
        super().__init__(config, server_configs, slurm, testing)
        self.config = config
        self.dataset = None
        self.iter = 0
        self.metric_buffer = {}
        self.reward_function = self._initialize_reward_function()

    def _initialize_reward_function(self):
        if hasattr(self.config, "reward_functions") and self.config.reward_functions:
            if len(self.config.reward_functions) == 1:
                return registry.create(self.config.reward_functions[0])
            else:
                return CombinedReward(
                    rewards=self.config.reward_functions, normalization="sum"
                )
        elif hasattr(self.config, "reward_funcs") and self.config.reward_funcs:
            if len(self.config.reward_funcs) == 1:
                return registry.create(self.config.reward_funcs[0])
            else:
                return CombinedReward(
                    rewards=self.config.reward_funcs, normalization="none"
                )
        return None

    async def setup(self):
        if self.config.dataset_path:
            self.dataset = load_dataset(
                self.config.dataset_path, split=self.config.split
            )
        else:
            self.dataset = load_dataset(
                self.config.dataset_name,
                self.config.dataset_config,
                split=self.config.split,
            )

        logger.info(f"Dataset features: {self.dataset.features}")
        logger.info(f"Sample item keys: {list(self.dataset[0].keys())}")
        logger.info(f"Sample item: {self.dataset[0]}")

        if self.config.shuffle_dataset:
            self.dataset = self.dataset.shuffle()

        self.metric_buffer = {}

    async def get_next_item(self) -> Item:
        if not self.dataset:
            await self.setup()

        item = self.dataset[self.iter % len(self.dataset)]
        self.iter += 1

        user_msg = {"role": "user", "content": item[self.config.prompt_field]}
        prompt = tuple([frozenset(user_msg.items())])

        answer = None
        if self.config.answer_field and self.config.answer_field in item:
            answer = item[self.config.answer_field]

        ground_truth = None
        if self.config.ground_truth_field and self.config.ground_truth_field in item:
            ground_truth = item[self.config.ground_truth_field]

        return (prompt, answer, ground_truth)

    async def collect_trajectory(self, item: Item) -> Tuple[List, List]:
        # Extract user prompt and answer from item
        user_content = dict(item[0][0])["content"]
        answer = item[1] if len(item) > 1 else None

        # Create messages list
        messages = []
        if self.config.system_prompt:
            messages.append({"role": "system", "content": self.config.system_prompt})
        messages.append({"role": "user", "content": user_content})

        # Add prefill as assistant message if configured
        if self.config.prefill:
            messages.append({"role": "assistant", "content": self.config.prefill})

        # Convert messages to a prompt string using the tokenizer
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)

        # Calculate max tokens for generation (with optional warmup)
        max_tokens = self.config.max_tokens
        if self.config.length_warmup_steps > 0:
            warmup_progress = min(
                1.0, self.curr_step / self.config.length_warmup_steps
            )
            max_tokens = int(
                self.config.min_tokens
                + warmup_progress * (self.config.max_tokens - self.config.min_tokens)
            )

        # Generate completion using completions API
        completions = await self.server.completion(
            prompt=prompt,
            n=self.config.max_generations_per_prompt,
            max_tokens=max_tokens,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
        )

        to_score = []
        to_backlog = []

        # Process completions
        for completion in completions.choices:
            # Get the completion text
            completion_text = (
                completion.text
                if hasattr(completion, "text")
                else completion.message.content
            )

            # Build full message sequence for scoring
            full_messages = []
            if self.config.system_prompt:
                full_messages.append(
                    {"role": "system", "content": self.config.system_prompt}
                )
            full_messages.append({"role": "user", "content": user_content})

            # Combine prefill with completion if prefill was used
            response_content = completion_text
            if self.config.prefill:
                response_content = self.config.prefill + completion_text
            full_messages.append({"role": "assistant", "content": response_content})

            # Add to scoring list with answer and ground truth
            to_score.append(
                (full_messages, answer, item[2] if len(item) > 2 else None)
            )

        return to_score, to_backlog

    async def postprocess_histories(self, trajectories: List) -> Tuple[List, List]:
        return trajectories, []

    async def collect_trajectories(self, item: Item) -> Tuple[List, List]:
        self.current_item = item

        # Extract user prompt from item
        user_content = dict(item[0][0])["content"]

        # Create messages list
        messages = []
        if self.config.system_prompt:
            messages.append({"role": "system", "content": self.config.system_prompt})
        messages.append({"role": "user", "content": user_content})

        # Add prefill as assistant message if configured
        if self.config.prefill:
            messages.append({"role": "assistant", "content": self.config.prefill})

        # Convert messages to a prompt string using the tokenizer
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)

        # Calculate max tokens for generation (with optional warmup)
        max_tokens = self.config.max_tokens

        # Generate completions
        completions = await self.server.completion(
            prompt=prompt,
            n=self.config.group_size,
            max_tokens=max_tokens,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
        )
        logger.debug(f"Completions: {completions}")

        # Process completions
        trajectories = []
        for completion in completions.choices:
            # Get the completion text
            completion_text = (
                completion.text
                if hasattr(completion, "text")
                else completion.message.content
            )

            # Build complete message sequence
            full_messages = []
            if self.config.system_prompt:
                full_messages.append(
                    {"role": "system", "content": self.config.system_prompt}
                )
            full_messages.append({"role": "user", "content": user_content})

            # Combine prefill with completion if prefill was used
            response_content = completion_text
            if self.config.prefill:
                response_content = self.config.prefill + completion_text
            full_messages.append({"role": "assistant", "content": response_content})

            trajectories.append(full_messages)

        return trajectories, []

    async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]:
        logger.warning(f"Scoring {len(rollout_group_data)} rollout items")
        scores = ScoredDataGroup()
        scores["tokens"] = []
        scores["masks"] = []
        scores["scores"] = []
        scores["advantages"] = None
        scores["ref_logprobs"] = None
        scores["messages"] = (
            None if not self.config.include_messages_in_scoring else []
        )

        answer = (
            self.current_item[1]
            if self.current_item and len(self.current_item) > 1
            else None
        )
        logger.warning(f"Answer for current item: {answer}")

        ground_truth = (
            self.current_item[2]
            if self.current_item and len(self.current_item) > 2
            else None
        )
        logger.warning(f"Ground truth for current item: {ground_truth}")

        formatted_completions = []
        for trajectory in rollout_group_data:
            if trajectory and isinstance(trajectory, list):
                assistant_messages = [
                    msg
                    for msg in trajectory
                    if isinstance(msg, dict) and msg.get("role") == "assistant"
                ]
                if assistant_messages:
                    formatted_completions.append([assistant_messages[-1]])

        if not formatted_completions:
            logger.warning("No valid completions to score")
            return None

        try:
            reward_kwargs = {
                "solution": answer,
                "ground_truth": ground_truth,
                "item": self.current_item,
                "config": self.config,
            }
            all_rewards = self.reward_function(formatted_completions, **reward_kwargs)
            logger.info(f"Calculated rewards: {all_rewards}")
        except Exception as e:
            logger.error(f"Error applying reward functions: {e}")
            logger.exception(e)
            all_rewards = [0.0] * len(formatted_completions)

        for i, (trajectory, reward) in enumerate(zip(rollout_group_data, all_rewards)):
            try:
                tokenized = tokenize_for_trainer(self.tokenizer, trajectory)
                scores["tokens"].append(tokenized["tokens"])
                scores["masks"].append(tokenized["masks"])
                scores["scores"].append(reward)
                if self.config.include_messages_in_scoring:
                    if "messages" not in scores:
                        scores["messages"] = []
                    scores["messages"].append(trajectory)
                logger.warning(f"Scores: {scores['scores']}")
            except Exception as e:
                logger.error(f"Error processing trajectory {i}: {e}")
                logger.exception(e)

        if not scores["tokens"]:
            logger.warning("No valid scores generated")
            return None

        logger.info(f"Generated scores: {scores['scores']}")
        return scores

    async def evaluate(self):
        if (
            not hasattr(self.config, "eval_dataset_name")
            or not self.config.eval_dataset_name
        ):
            return

        if not hasattr(self, "eval_dataset"):
            self.eval_dataset = load_dataset(
                self.config.eval_dataset_name,
                self.config.eval_dataset_config,
                split=self.config.eval_split,
            )
            self.eval_dataset = self.eval_dataset.select(
                range(min(100, len(self.eval_dataset)))
            )

        eval_metrics = {}
        eval_tasks = []
        for i in range(min(self.config.max_eval_workers, len(self.eval_dataset))):
            item = self.eval_dataset[i]
            user_msg = {"role": "user", "content": item[self.config.prompt_field]}
            prompt = tuple([frozenset(user_msg.items())])

            answer = None
            if self.config.answer_field and self.config.answer_field in item:
                answer = item[self.config.answer_field]

            eval_tasks.append(self.collect_trajectory((prompt, answer)))

        eval_results = await asyncio.gather(*eval_tasks)

        eval_scores = []
        for result in eval_results:
            if result[0]:
                scored_data = await self.score(result[0])
                if scored_data and "scores" in scored_data:
                    eval_scores.extend(scored_data["scores"])

        if eval_scores:
            eval_metrics["eval/mean_score"] = sum(eval_scores) / len(eval_scores)
            eval_metrics["eval/max_score"] = max(eval_scores)
            eval_metrics["eval/min_score"] = min(eval_scores)

        await self.wandb_log(eval_metrics)

    async def wandb_log(self, wandb_metrics: Optional[Dict] = None):
        metrics = wandb_metrics or {}

        for key, values in self.metric_buffer.items():
            if values:
                metrics[f"train/{key}"] = sum(values) / len(values)
        self.metric_buffer = {k: [] for k in self.metric_buffer}

        if hasattr(self, "reward_function") and self.wandb:
            if hasattr(self.reward_function, "set_wandb_logger"):
                self.reward_function.set_wandb_logger(self.wandb)

        await super().wandb_log(metrics)


if __name__ == "__main__":
    # Launch the DatasetEnv via the BaseEnv CLI (serve or process)
    DatasetEnv.cli()
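
# Illustrative sketch (not part of the module): roughly how DatasetEnvConfig
# might be filled in for a GSM8K-style dataset. The field names come from the
# class above; the dataset name, field values, and the "accuracy" reward name
# are assumptions, and BaseEnvConfig may require additional fields not shown.
#
#   config = DatasetEnvConfig(
#       dataset_name="gsm8k",
#       dataset_config="main",
#       split="train",
#       prompt_field="question",
#       answer_field="answer",
#       reward_functions=["accuracy"],
#   )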