diff --git a/environments/text_reversal_environment.py b/environments/text_reversal_environment.py index eb875dc1..63624139 100644 --- a/environments/text_reversal_environment.py +++ b/environments/text_reversal_environment.py @@ -33,7 +33,7 @@ class TextReversalConfig(BaseEnvConfig): default=None, description="Custom thinking prompt. If None, uses the default thinking prompt.", ) - + custom_thinking_prompt: Optional[str] = Field( default=None, description="Custom thinking prompt. If None, uses the default thinking prompt.", @@ -258,7 +258,7 @@ class TextReversalEnv(BaseEnv): APIServerConfig( model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview", base_url="http://localhost:9004/v1", - api_key="x" + api_key="x", ), ] return env_config, server_configs @@ -304,35 +304,39 @@ class TextReversalEnv(BaseEnv): if hasattr(self.train, "__iter__"): total_train_items = len(self.train) print(f"\nTraining dataset analysis ({total_train_items} total items):") - + # Show some sample text lengths text_lengths = [] for item in list(self.train)[:100]: # Sample first 100 items text = item.get("text", "") text_lengths.append(len(text)) - + if text_lengths: avg_length = sum(text_lengths) / len(text_lengths) min_length = min(text_lengths) max_length = max(text_lengths) - print(f" - Sample text lengths: avg={avg_length:.1f}, min={min_length}, max={max_length}") + print( + f" - Sample text lengths: avg={avg_length:.1f}, min={min_length}, max={max_length}" + ) # Analyze evaluation dataset composition if hasattr(self.test, "__iter__"): total_eval_items = len(self.test) print(f"\nEvaluation dataset analysis ({total_eval_items} total items):") - + # Show some sample text lengths text_lengths = [] for item in list(self.test)[:100]: # Sample first 100 items text = item.get("text", "") text_lengths.append(len(text)) - + if text_lengths: avg_length = sum(text_lengths) / len(text_lengths) min_length = min(text_lengths) max_length = max(text_lengths) - print(f" - Sample text lengths: avg={avg_length:.1f}, min={min_length}, max={max_length}") + print( + f" - Sample text lengths: avg={avg_length:.1f}, min={min_length}, max={max_length}" + ) # Show configuration info print("\nText Reversal Configuration:") @@ -361,9 +365,7 @@ class TextReversalEnv(BaseEnv): print( "\nšŸ” FULL DEBUG MODE ENABLED - Will log all API requests and responses" ) - print( - " šŸ“Š Will show: first/last 100 chars of prompts and responses" - ) + print(" šŸ“Š Will show: first/last 100 chars of prompts and responses") print( f" āš™ļø Retry settings: max_retries={self.config.max_retries}, retry_delay={self.config.retry_delay}s" ) @@ -394,28 +396,20 @@ class TextReversalEnv(BaseEnv): # Local file - use appropriate loader based on extension if dataset_path.endswith(".jsonl") or dataset_path.endswith(".json"): dataset = load_dataset( - "json", - data_files=dataset_path, - split=split or "train" + "json", data_files=dataset_path, split=split or "train" ) elif dataset_path.endswith(".csv"): dataset = load_dataset( - "csv", - data_files=dataset_path, - split=split or "train" + "csv", data_files=dataset_path, split=split or "train" ) elif dataset_path.endswith(".parquet"): dataset = load_dataset( - "parquet", - data_files=dataset_path, - split=split or "train" + "parquet", data_files=dataset_path, split=split or "train" ) else: # Try JSON as default dataset = load_dataset( - "json", - data_files=dataset_path, - split=split or "train" + "json", data_files=dataset_path, split=split or "train" ) print( @@ -425,9 +419,7 @@ class TextReversalEnv(BaseEnv): else: # HuggingFace dataset if split: - dataset = load_dataset( - dataset_path, split=split - ) + dataset = load_dataset(dataset_path, split=split) else: dataset_dict = load_dataset(dataset_path) # If no split specified, try to get the first available split @@ -463,10 +455,10 @@ class TextReversalEnv(BaseEnv): def _extract_reversed_text(self, response: str) -> Optional[str]: """ Extract text from within tags. - + Args: response: Model response text - + Returns: Extracted text or None if not found or multiple blocks found """ @@ -487,11 +479,11 @@ class TextReversalEnv(BaseEnv): # Find all content between tags matches = self._reversed_text_pattern.findall(response) - + # Must have exactly one reversed_text block if len(matches) != 1: return None - + return matches[0].strip() def _create_reversal_prompt(self, text: str) -> str: @@ -539,7 +531,7 @@ class TextReversalEnv(BaseEnv): """ try: original_text = item.get("text", "") - + # Validate required fields if not original_text: return None, None @@ -723,7 +715,7 @@ class TextReversalEnv(BaseEnv): # Extract reversed text from model response extracted_reversed = self._extract_reversed_text(model_response) - + # Score 1.0 if exact match, 0.0 otherwise reward = 1.0 if extracted_reversed == expected_reversed else 0.0 @@ -760,12 +752,18 @@ class TextReversalEnv(BaseEnv): group_size = len(scores["scores"]) any_success = group_successes > 0 success_indicator = "āœ…" if any_success else "āŒ" - + # Calculate running totals - total_success_rate = (self.successful_reversals / self.total_attempts * 100) if self.total_attempts > 0 else 0.0 - - print(f"{success_indicator} Group scored: {group_successes}/{group_size} successful reversals | " - f"Total success rate: {self.successful_reversals}/{self.total_attempts} ({total_success_rate:.1f}%)") + total_success_rate = ( + (self.successful_reversals / self.total_attempts * 100) + if self.total_attempts > 0 + else 0.0 + ) + + print( + f"{success_indicator} Group scored: {group_successes}/{group_size} successful reversals | " + f"Total success rate: {self.successful_reversals}/{self.total_attempts} ({total_success_rate:.1f}%)" + ) # Update percent correct buffer for score in scores["scores"]: @@ -898,7 +896,7 @@ class TextReversalEnv(BaseEnv): # Extract reversed text from model response extracted_reversed = self._extract_reversed_text(model_response) - + # Score 1.0 if exact match, 0.0 otherwise score = 1.0 if extracted_reversed == expected_reversed else 0.0 @@ -956,8 +954,7 @@ class TextReversalEnv(BaseEnv): try: eval_tasks = [ - self.rollout_and_score_eval(test_item) - for test_item in self.test + self.rollout_and_score_eval(test_item) for test_item in self.test ] results = await tqdm_asyncio.gather(*eval_tasks) @@ -994,7 +991,7 @@ class TextReversalEnv(BaseEnv): format_compliant = sum( 1 for sample in samples if sample.get("format_compliant", False) ) - + thinking_mode_used = self.config.thinking_mode # Get response metrics @@ -1104,16 +1101,18 @@ class TextReversalEnv(BaseEnv): if role_dict_converted.get("role") == "user": user_content = role_dict_converted.get("content", "") # Extract original text from the user message - lines = user_content.split('\n') + lines = user_content.split("\n") for line in lines: - if line.strip() and not line.startswith("Please reverse") and not line.startswith("The text to reverse:"): + if ( + line.strip() + and not line.startswith("Please reverse") + and not line.startswith("The text to reverse:") + ): original_text = line.strip() break break except Exception as e: - print( - f"DEBUG: Exception in add_rollouts_for_wandb text extraction: {e}" - ) + print(f"DEBUG: Exception in add_rollouts_for_wandb text extraction: {e}") original_text = "extraction_failed" # Keep a reasonable number of rollouts @@ -1144,7 +1143,9 @@ class TextReversalEnv(BaseEnv): # Fallback to decoding tokens model_response = full_text - extracted_reversed = self._extract_reversed_text(model_response) or "format_error" + extracted_reversed = ( + self._extract_reversed_text(model_response) or "format_error" + ) except Exception as e: print( f"DEBUG: Exception in add_rollouts_for_wandb reversal parsing: {e}" @@ -1206,10 +1207,18 @@ class TextReversalEnv(BaseEnv): # Reversal-specific metrics if self.total_attempts > 0: - wandb_metrics["train/success_rate"] = self.successful_reversals / self.total_attempts - wandb_metrics["train/failure_rate"] = self.failed_reversals / self.total_attempts - wandb_metrics["train/format_error_rate"] = self.format_errors / self.total_attempts - wandb_metrics["train/format_compliance_rate"] = 1.0 - (self.format_errors / self.total_attempts) + wandb_metrics["train/success_rate"] = ( + self.successful_reversals / self.total_attempts + ) + wandb_metrics["train/failure_rate"] = ( + self.failed_reversals / self.total_attempts + ) + wandb_metrics["train/format_error_rate"] = ( + self.format_errors / self.total_attempts + ) + wandb_metrics["train/format_compliance_rate"] = 1.0 - ( + self.format_errors / self.total_attempts + ) # Configuration metrics wandb_metrics.update(