diff --git a/environments/pairwise_judgement_environment.py b/environments/pairwise_judgement_environment.py
index 4e76d56f..e970f62d 100644
--- a/environments/pairwise_judgement_environment.py
+++ b/environments/pairwise_judgement_environment.py
@@ -398,13 +398,16 @@ class PairwiseJudgementEnv(BaseEnv):
         """Set up the environment by loading datasets."""
         # Load training dataset
         try:
-            self.train = self._load_dataset(self.config.train_dataset, self.config.train_split)
+            self.train = self._load_dataset(
+                self.config.train_dataset, self.config.train_split
+            )
             # Shuffle training dataset for reproducibility
-            if hasattr(self.train, 'shuffle'):
+            if hasattr(self.train, "shuffle"):
                 self.train = self.train.shuffle(seed=42)
             else:
                 # For list-like objects, convert to list and shuffle
                 import random
+
                 train_list = list(self.train)
                 random.seed(42)
                 random.shuffle(train_list)
@@ -413,42 +416,60 @@ class PairwiseJudgementEnv(BaseEnv):
             print(f"Error loading training dataset '{self.config.train_dataset}': {e}")
             # Create minimal fallback data in expected format
             self.train = [
-                {"uid": "train_1", "category": "general", "prompt": "What is the capital of France?"},
-                {"uid": "train_2", "category": "math", "prompt": "Solve for x: 2x + 5 = 15"},
-                {"uid": "train_3", "category": "coding", "prompt": "Write a Python function to calculate factorial"},
+                {
+                    "uid": "train_1",
+                    "category": "general",
+                    "prompt": "What is the capital of France?",
+                },
+                {
+                    "uid": "train_2",
+                    "category": "math",
+                    "prompt": "Solve for x: 2x + 5 = 15",
+                },
+                {
+                    "uid": "train_3",
+                    "category": "coding",
+                    "prompt": "Write a Python function to calculate factorial",
+                },
             ] * 34  # 102 examples
             print(f"Using fallback training data with {len(self.train)} examples")
 
         # Load evaluation dataset
         try:
-            self.test = self._load_dataset(self.config.eval_dataset, self.config.eval_split)
+            self.test = self._load_dataset(
+                self.config.eval_dataset, self.config.eval_split
+            )
         except Exception as e:
             print(f"Error loading evaluation dataset '{self.config.eval_dataset}': {e}")
             raise  # Evaluation dataset must work
 
         # Analyze training dataset composition
-        if hasattr(self.train, '__iter__'):
+        if hasattr(self.train, "__iter__"):
             train_category_counts = {}
             total_train_items = 0
-            
+
             for item in self.train:
                 total_train_items += 1
                 category = item.get("category", "Unknown")
-                train_category_counts[category] = train_category_counts.get(category, 0) + 1
+                train_category_counts[category] = (
+                    train_category_counts.get(category, 0) + 1
+                )
 
             print(f"\nTraining dataset analysis ({total_train_items} total items):")
             for category, count in sorted(train_category_counts.items()):
                 print(f"  - {category}: {count} samples")
 
         # Analyze evaluation dataset composition
-        if hasattr(self.test, '__iter__'):
+        if hasattr(self.test, "__iter__"):
             eval_category_counts = {}
             total_eval_items = 0
-            
+
             for item in self.test:
                 total_eval_items += 1
                 category = item.get("subset", "Unknown")
-                eval_category_counts[category] = eval_category_counts.get(category, 0) + 1
+                eval_category_counts[category] = (
+                    eval_category_counts.get(category, 0) + 1
+                )
 
             print(f"\nEvaluation dataset analysis ({total_eval_items} total items):")
             for category, count in sorted(eval_category_counts.items()):
@@ -465,7 +486,9 @@ class PairwiseJudgementEnv(BaseEnv):
         if self.config.eval_categories is not None:
             selected_categories = []
             for cat in self.config.eval_categories:
-                selected_categories.append(cat.value if hasattr(cat, 'value') else str(cat))
+                selected_categories.append(
+                    cat.value if hasattr(cat, "value") else str(cat)
+                )
             print(
                 f"\nCategory filtering enabled. Selected categories: {selected_categories}"
             )
@@ -480,8 +503,12 @@ class PairwiseJudgementEnv(BaseEnv):
 
         # Show configuration info
         print(f"\nPairwise Judgement Configuration:")
-        print(f"  - Training dataset: {self.config.train_dataset} (split: {self.config.train_split})")
-        print(f"  - Evaluation dataset: {self.config.eval_dataset} (split: {self.config.eval_split})")
+        print(
+            f"  - Training dataset: {self.config.train_dataset} (split: {self.config.train_split})"
+        )
+        print(
+            f"  - Evaluation dataset: {self.config.eval_dataset} (split: {self.config.eval_split})"
+        )
         print(f"  - Thinking mode: {self.config.thinking_mode}")
         print(f"  - Eval temperature: {self.config.eval_temperature}")
         print(f"  - Number of choices: {self.config.num_choices}")
@@ -492,7 +519,7 @@ class PairwiseJudgementEnv(BaseEnv):
             sample_train_item = self.train[0]
             print(f"\nSample training item structure:")
             print(f"- Available keys: {list(sample_train_item.keys())}")
-            
+
             if "uid" in sample_train_item:
                 print(f"- UID: {sample_train_item['uid']}")
             if "category" in sample_train_item:
@@ -515,12 +542,16 @@ class PairwiseJudgementEnv(BaseEnv):
                 # Handle different dataset structures
                 if "prompt" in sample_eval_item:
                     print(f"- Prompt: {sample_eval_item['prompt'][:100]}...")
-                elif "chosen" in sample_eval_item and isinstance(sample_eval_item["chosen"], str):
+                elif "chosen" in sample_eval_item and isinstance(
+                    sample_eval_item["chosen"], str
+                ):
                     print(f"- Chosen (string): {sample_eval_item['chosen'][:100]}...")
                 elif "rejected" in sample_eval_item and isinstance(
                     sample_eval_item["rejected"], str
                 ):
-                    print(f"- Rejected (string): {sample_eval_item['rejected'][:100]}...")
+                    print(
+                        f"- Rejected (string): {sample_eval_item['rejected'][:100]}..."
+                    )
 
                 if "chosen" in sample_eval_item:
                     if isinstance(sample_eval_item["chosen"], list):
@@ -530,11 +561,15 @@ class PairwiseJudgementEnv(BaseEnv):
                                 f"- First chosen (truncated): {sample_eval_item['chosen'][0][:200]}..."
                             )
                     else:
-                        print(f"- Chosen (string): {sample_eval_item['chosen'][:200]}...")
+                        print(
+                            f"- Chosen (string): {sample_eval_item['chosen'][:200]}..."
+                        )
 
                 if "rejected" in sample_eval_item:
                     if isinstance(sample_eval_item["rejected"], list):
-                        print(f"- Rejected responses: {len(sample_eval_item['rejected'])}")
+                        print(
+                            f"- Rejected responses: {len(sample_eval_item['rejected'])}"
+                        )
                         if sample_eval_item["rejected"]:
                             print(
                                 f"- First rejected (truncated): {sample_eval_item['rejected'][0][:200]}..."
@@ -545,7 +580,9 @@ class PairwiseJudgementEnv(BaseEnv):
                         )
 
             except Exception as e:
-                print(f"Warning: Could not display sample evaluation item structure: {e}")
+                print(
+                    f"Warning: Could not display sample evaluation item structure: {e}"
+                )
 
         # Show debug mode status
         if self.config.full_debug:
@@ -569,53 +606,81 @@ class PairwiseJudgementEnv(BaseEnv):
     def _load_dataset(self, dataset_path: str, split: str = None) -> List[Dict]:
         """
         Load dataset using HuggingFace load_dataset (supports both HF datasets and local files).
-        
+
         Args:
             dataset_path: Either HuggingFace dataset name or path to local file
             split: Split to use
-        
+
         Returns:
             List of dataset items
         """
         import os
-        
+
         try:
             # Check if it's a local file
             if os.path.exists(dataset_path):
                 # Local file - use appropriate loader based on extension
-                if dataset_path.endswith('.jsonl') or dataset_path.endswith('.json'):
-                    dataset = load_dataset("json", data_files=dataset_path, split=split or "train", trust_remote_code=True)
-                elif dataset_path.endswith('.csv'):
-                    dataset = load_dataset("csv", data_files=dataset_path, split=split or "train", trust_remote_code=True)
-                elif dataset_path.endswith('.parquet'):
-                    dataset = load_dataset("parquet", data_files=dataset_path, split=split or "train", trust_remote_code=True)
+                if dataset_path.endswith(".jsonl") or dataset_path.endswith(".json"):
+                    dataset = load_dataset(
+                        "json",
+                        data_files=dataset_path,
+                        split=split or "train",
+                        trust_remote_code=True,
+                    )
+                elif dataset_path.endswith(".csv"):
+                    dataset = load_dataset(
+                        "csv",
+                        data_files=dataset_path,
+                        split=split or "train",
+                        trust_remote_code=True,
+                    )
+                elif dataset_path.endswith(".parquet"):
+                    dataset = load_dataset(
+                        "parquet",
+                        data_files=dataset_path,
+                        split=split or "train",
+                        trust_remote_code=True,
+                    )
                 else:
                     # Try JSON as default
-                    dataset = load_dataset("json", data_files=dataset_path, split=split or "train", trust_remote_code=True)
-                
-                print(f"Loaded local dataset from {dataset_path} with {len(dataset)} examples")
-                
+                    dataset = load_dataset(
+                        "json",
+                        data_files=dataset_path,
+                        split=split or "train",
+                        trust_remote_code=True,
+                    )
+
+                print(
+                    f"Loaded local dataset from {dataset_path} with {len(dataset)} examples"
+                )
+
             else:
                 # HuggingFace dataset
                 if split:
-                    dataset = load_dataset(dataset_path, split=split, trust_remote_code=True)
+                    dataset = load_dataset(
+                        dataset_path, split=split, trust_remote_code=True
+                    )
                 else:
                     dataset_dict = load_dataset(dataset_path, trust_remote_code=True)
                     # If no split specified, try to get the first available split
-                    if hasattr(dataset_dict, 'keys'):
+                    if hasattr(dataset_dict, "keys"):
                         available_splits = list(dataset_dict.keys())
                         if available_splits:
                             dataset = dataset_dict[available_splits[0]]
-                            print(f"No split specified, using '{available_splits[0]}' split")
+                            print(
+                                f"No split specified, using '{available_splits[0]}' split"
+                            )
                         else:
                             dataset = dataset_dict
                     else:
                         dataset = dataset_dict
-                
-                print(f"Loaded HuggingFace dataset {dataset_path} with {len(dataset)} examples")
-                
+
+                print(
+                    f"Loaded HuggingFace dataset {dataset_path} with {len(dataset)} examples"
+                )
+
             return dataset
-        
+
         except Exception as e:
             print(f"Error loading dataset {dataset_path}: {e}")
             raise
@@ -711,17 +776,21 @@ class PairwiseJudgementEnv(BaseEnv):
 
             # Generate different quality responses to compare
             answer_generation_prompt = self._create_system_content()
-            
-            answer_prompt = tuple([
-                frozenset({"role": "system", "content": answer_generation_prompt}.items()),
-                frozenset({"role": "user", "content": prompt_text}.items()),
-            ])
+
+            answer_prompt = tuple(
+                [
+                    frozenset(
+                        {"role": "system", "content": answer_generation_prompt}.items()
+                    ),
+                    frozenset({"role": "user", "content": prompt_text}.items()),
+                ]
+            )
 
             # Generate multiple responses with different quality levels for comparison
             try:
                 # Generate responses with different parameters to get varied quality
                 high_temp_messages = self._prepare_completion_input(answer_prompt)
-                
+
                 # High temperature response (more creative/varied, potentially lower quality)
                 high_temp_completion = await self.server.chat_completion(
                     messages=high_temp_messages,
@@ -729,7 +798,7 @@ class PairwiseJudgementEnv(BaseEnv):
                     max_tokens=self.config.train_max_tokens // 2,
                     temperature=1.2,
                 )
-                
+
                 # Low temperature response (more conservative, potentially higher quality)
                 low_temp_completion = await self.server.chat_completion(
                     messages=high_temp_messages,
@@ -737,31 +806,36 @@ class PairwiseJudgementEnv(BaseEnv):
                     max_tokens=self.config.train_max_tokens // 2,
                     temperature=0.3,
                 )
-                
-                if (high_temp_completion.choices and low_temp_completion.choices and
-                    high_temp_completion.choices[0].message.content and
-                    low_temp_completion.choices[0].message.content):
-                    
+
+                if (
+                    high_temp_completion.choices
+                    and low_temp_completion.choices
+                    and high_temp_completion.choices[0].message.content
+                    and low_temp_completion.choices[0].message.content
+                ):
+
                     high_temp_answer = high_temp_completion.choices[0].message.content
                     low_temp_answer = low_temp_completion.choices[0].message.content
-                    
+
                     # Create list of answers for comparison
                     answers = [low_temp_answer, high_temp_answer]
-                    
+
                     # Pad with generic answers if we need more choices
                     while len(answers) < self.config.num_choices:
-                        answers.append("I don't have enough information to answer this question thoroughly.")
-                    
+                        answers.append(
+                            "I don't have enough information to answer this question thoroughly."
+                        )
+
                     # Take only the number of choices we need
-                    answers = answers[:self.config.num_choices]
-                    
+                    answers = answers[: self.config.num_choices]
+
                     # Randomly shuffle positions
                     random.shuffle(answers)
-                    
+
                     # Find where the low temp (better) answer ended up
                     correct_index = answers.index(low_temp_answer)
                     correct_answer = self.choice_letters[correct_index]
-                    
+
                 else:
                     # Fallback if generation fails
                     answers = [
@@ -770,13 +844,17 @@ class PairwiseJudgementEnv(BaseEnv):
                     ]
                     # Pad to required number of choices
                     while len(answers) < self.config.num_choices:
-                        answers.append("I don't have sufficient information to provide a complete answer.")
-                    
-                    answers = answers[:self.config.num_choices]
+                        answers.append(
+                            "I don't have sufficient information to provide a complete answer."
+                        )
+
+                    answers = answers[: self.config.num_choices]
                     random.shuffle(answers)
-                    correct_index = answers.index("This is a comprehensive and well-structured response that addresses the question thoroughly with detailed examples and clear explanations.")
+                    correct_index = answers.index(
+                        "This is a comprehensive and well-structured response that addresses the question thoroughly with detailed examples and clear explanations."
+                    )
                     correct_answer = self.choice_letters[correct_index]
-                    
+
             except Exception as e:
                 # Fallback if generation fails
                 print(f"Warning: Failed to generate training responses: {e}")
@@ -787,20 +865,24 @@ class PairwiseJudgementEnv(BaseEnv):
                 # Pad to required number of choices
                 while len(answers) < self.config.num_choices:
                     answers.append("Insufficient information provided.")
-                
-                answers = answers[:self.config.num_choices]
+
+                answers = answers[: self.config.num_choices]
                 random.shuffle(answers)
-                correct_index = answers.index("This is a comprehensive and detailed response that properly addresses all aspects of the question.")
+                correct_index = answers.index(
+                    "This is a comprehensive and detailed response that properly addresses all aspects of the question."
+                )
                 correct_answer = self.choice_letters[correct_index]
 
         # Create judgment prompt
         system_content = self._create_system_content()
         user_content = self.create_judgment_prompt(prompt_text, answers)
 
-        prompt = tuple([
-            frozenset({"role": "system", "content": system_content}.items()),
-            frozenset({"role": "user", "content": user_content}.items()),
-        ])
+        prompt = tuple(
+            [
+                frozenset({"role": "system", "content": system_content}.items()),
+                frozenset({"role": "user", "content": user_content}.items()),
+            ]
+        )
 
         return (prompt, correct_answer)
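
For reference, the loader dispatch that the reformatted `_load_dataset` implements, reduced to a standalone Python sketch. This is illustrative only: the helper name `load_any` is hypothetical, and `trust_remote_code` is omitted for brevity.

    import os

    from datasets import load_dataset


    def load_any(path: str, split: str = None):
        """Load a local JSON/JSONL/CSV/Parquet file, or a Hub dataset by name."""
        if os.path.exists(path):
            # Map the file extension to the builder name load_dataset expects;
            # anything unrecognized falls back to the JSON builder, as above.
            builders = {".jsonl": "json", ".json": "json", ".csv": "csv", ".parquet": "parquet"}
            builder = builders.get(os.path.splitext(path)[1], "json")
            return load_dataset(builder, data_files=path, split=split or "train")
        if split is not None:
            # Hub dataset with an explicit split.
            return load_dataset(path, split=split)
        # No split given: fall back to the first available split, as above.
        dataset_dict = load_dataset(path)
        return dataset_dict[next(iter(dataset_dict))]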