[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
pre-commit-ci[bot] 2025-08-20 07:28:34 +00:00
parent a8c3e67062
commit 8b0a70131b
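The hunks below are formatting-only rewrites of the kind an auto-formatter produces: long calls wrapped across lines, trailing commas added to multi-line argument lists, and string quotes normalized. The hook configuration itself is not part of this commit, so black (or ruff-format) is only an assumption; a minimal before/after sketch of the pattern:

# Illustrative sketch only; dummy values so the snippet runs on its own.
avg_length, min_length, max_length = 42.0, 3, 187

# Before the hook: one long call on a single line (valid, just over the line-length limit).
print(f" - Sample text lengths: avg={avg_length:.1f}, min={min_length}, max={max_length}")

# After the hook: the call is wrapped; behaviour is unchanged.
print(
    f" - Sample text lengths: avg={avg_length:.1f}, min={min_length}, max={max_length}"
)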


@@ -33,7 +33,7 @@ class TextReversalConfig(BaseEnvConfig):
default=None,
description="Custom thinking prompt. If None, uses the default thinking prompt.",
)
custom_thinking_prompt: Optional[str] = Field(
default=None,
description="Custom thinking prompt. If None, uses the default thinking prompt.",
@@ -258,7 +258,7 @@ class TextReversalEnv(BaseEnv):
APIServerConfig(
model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
base_url="http://localhost:9004/v1",
api_key="x"
api_key="x",
),
]
return env_config, server_configs
@@ -304,35 +304,39 @@ class TextReversalEnv(BaseEnv):
if hasattr(self.train, "__iter__"):
total_train_items = len(self.train)
print(f"\nTraining dataset analysis ({total_train_items} total items):")
# Show some sample text lengths
text_lengths = []
for item in list(self.train)[:100]: # Sample first 100 items
text = item.get("text", "")
text_lengths.append(len(text))
if text_lengths:
avg_length = sum(text_lengths) / len(text_lengths)
min_length = min(text_lengths)
max_length = max(text_lengths)
print(f" - Sample text lengths: avg={avg_length:.1f}, min={min_length}, max={max_length}")
print(
f" - Sample text lengths: avg={avg_length:.1f}, min={min_length}, max={max_length}"
)
# Analyze evaluation dataset composition
if hasattr(self.test, "__iter__"):
total_eval_items = len(self.test)
print(f"\nEvaluation dataset analysis ({total_eval_items} total items):")
# Show some sample text lengths
text_lengths = []
for item in list(self.test)[:100]: # Sample first 100 items
text = item.get("text", "")
text_lengths.append(len(text))
if text_lengths:
avg_length = sum(text_lengths) / len(text_lengths)
min_length = min(text_lengths)
max_length = max(text_lengths)
print(f" - Sample text lengths: avg={avg_length:.1f}, min={min_length}, max={max_length}")
print(
f" - Sample text lengths: avg={avg_length:.1f}, min={min_length}, max={max_length}"
)
# Show configuration info
print("\nText Reversal Configuration:")
@@ -361,9 +365,7 @@ class TextReversalEnv(BaseEnv):
print(
"\n🔍 FULL DEBUG MODE ENABLED - Will log all API requests and responses"
)
print(
" 📊 Will show: first/last 100 chars of prompts and responses"
)
print(" 📊 Will show: first/last 100 chars of prompts and responses")
print(
f" ⚙️ Retry settings: max_retries={self.config.max_retries}, retry_delay={self.config.retry_delay}s"
)
@@ -394,28 +396,20 @@ class TextReversalEnv(BaseEnv):
# Local file - use appropriate loader based on extension
if dataset_path.endswith(".jsonl") or dataset_path.endswith(".json"):
dataset = load_dataset(
"json",
data_files=dataset_path,
split=split or "train"
"json", data_files=dataset_path, split=split or "train"
)
elif dataset_path.endswith(".csv"):
dataset = load_dataset(
"csv",
data_files=dataset_path,
split=split or "train"
"csv", data_files=dataset_path, split=split or "train"
)
elif dataset_path.endswith(".parquet"):
dataset = load_dataset(
"parquet",
data_files=dataset_path,
split=split or "train"
"parquet", data_files=dataset_path, split=split or "train"
)
else:
# Try JSON as default
dataset = load_dataset(
"json",
data_files=dataset_path,
split=split or "train"
"json", data_files=dataset_path, split=split or "train"
)
print(
@@ -425,9 +419,7 @@ class TextReversalEnv(BaseEnv):
else:
# HuggingFace dataset
if split:
dataset = load_dataset(
dataset_path, split=split
)
dataset = load_dataset(dataset_path, split=split)
else:
dataset_dict = load_dataset(dataset_path)
# If no split specified, try to get the first available split
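The two hunks above reformat the dataset-loading code, which picks a loader by file extension for local files and otherwise treats the path as a Hugging Face dataset name, falling back to the first available split. A standalone sketch of the same dispatch, assuming the datasets library; the helper name load_text_dataset is illustrative, and the local-vs-hub decision is simplified to an extension check here:

from typing import Optional

from datasets import load_dataset


def load_text_dataset(dataset_path: str, split: Optional[str] = None):
    # Local files: choose the loader that matches the extension.
    if dataset_path.endswith((".jsonl", ".json")):
        return load_dataset("json", data_files=dataset_path, split=split or "train")
    if dataset_path.endswith(".csv"):
        return load_dataset("csv", data_files=dataset_path, split=split or "train")
    if dataset_path.endswith(".parquet"):
        return load_dataset("parquet", data_files=dataset_path, split=split or "train")
    # Otherwise treat the path as a Hub dataset name.
    if split:
        return load_dataset(dataset_path, split=split)
    dataset_dict = load_dataset(dataset_path)
    # No split specified: take the first available split, as the original does.
    return dataset_dict[next(iter(dataset_dict))]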
@@ -463,10 +455,10 @@ class TextReversalEnv(BaseEnv):
def _extract_reversed_text(self, response: str) -> Optional[str]:
"""
Extract text from within <reversed_text> tags.
Args:
response: Model response text
Returns:
Extracted text or None if not found or multiple blocks found
"""
@@ -487,11 +479,11 @@ class TextReversalEnv(BaseEnv):
# Find all content between <reversed_text> tags
matches = self._reversed_text_pattern.findall(response)
# Must have exactly one reversed_text block
if len(matches) != 1:
return None
return matches[0].strip()
def _create_reversal_prompt(self, text: str) -> str:
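The two hunks above touch the <reversed_text> extraction helper: it runs a precompiled regex (self._reversed_text_pattern, defined outside this diff), requires exactly one match, and strips it. A minimal standalone sketch, with the pattern itself assumed since it is not visible here:

import re
from typing import Optional

# Assumed stand-in for self._reversed_text_pattern; the real pattern is not shown in this diff.
REVERSED_TEXT_PATTERN = re.compile(r"<reversed_text>(.*?)</reversed_text>", re.DOTALL)


def extract_reversed_text(response: str) -> Optional[str]:
    # Exactly one <reversed_text> block is required; anything else counts as a format error.
    matches = REVERSED_TEXT_PATTERN.findall(response)
    if len(matches) != 1:
        return None
    return matches[0].strip()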
@@ -539,7 +531,7 @@ class TextReversalEnv(BaseEnv):
"""
try:
original_text = item.get("text", "")
# Validate required fields
if not original_text:
return None, None
@@ -723,7 +715,7 @@ class TextReversalEnv(BaseEnv):
# Extract reversed text from model response
extracted_reversed = self._extract_reversed_text(model_response)
# Score 1.0 if exact match, 0.0 otherwise
reward = 1.0 if extracted_reversed == expected_reversed else 0.0
@@ -760,12 +752,18 @@ class TextReversalEnv(BaseEnv):
group_size = len(scores["scores"])
any_success = group_successes > 0
success_indicator = "" if any_success else ""
# Calculate running totals
total_success_rate = (self.successful_reversals / self.total_attempts * 100) if self.total_attempts > 0 else 0.0
print(f"{success_indicator} Group scored: {group_successes}/{group_size} successful reversals | "
f"Total success rate: {self.successful_reversals}/{self.total_attempts} ({total_success_rate:.1f}%)")
total_success_rate = (
(self.successful_reversals / self.total_attempts * 100)
if self.total_attempts > 0
else 0.0
)
print(
f"{success_indicator} Group scored: {group_successes}/{group_size} successful reversals | "
f"Total success rate: {self.successful_reversals}/{self.total_attempts} ({total_success_rate:.1f}%)"
)
# Update percent correct buffer
for score in scores["scores"]:
@@ -898,7 +896,7 @@ class TextReversalEnv(BaseEnv):
# Extract reversed text from model response
extracted_reversed = self._extract_reversed_text(model_response)
# Score 1.0 if exact match, 0.0 otherwise
score = 1.0 if extracted_reversed == expected_reversed else 0.0
@@ -956,8 +954,7 @@ class TextReversalEnv(BaseEnv):
try:
eval_tasks = [
self.rollout_and_score_eval(test_item)
for test_item in self.test
self.rollout_and_score_eval(test_item) for test_item in self.test
]
results = await tqdm_asyncio.gather(*eval_tasks)
@@ -994,7 +991,7 @@ class TextReversalEnv(BaseEnv):
format_compliant = sum(
1 for sample in samples if sample.get("format_compliant", False)
)
thinking_mode_used = self.config.thinking_mode
# Get response metrics
@@ -1104,16 +1101,18 @@ class TextReversalEnv(BaseEnv):
if role_dict_converted.get("role") == "user":
user_content = role_dict_converted.get("content", "")
# Extract original text from the user message
lines = user_content.split('\n')
lines = user_content.split("\n")
for line in lines:
if line.strip() and not line.startswith("Please reverse") and not line.startswith("The text to reverse:"):
if (
line.strip()
and not line.startswith("Please reverse")
and not line.startswith("The text to reverse:")
):
original_text = line.strip()
break
break
except Exception as e:
print(
f"DEBUG: Exception in add_rollouts_for_wandb text extraction: {e}"
)
print(f"DEBUG: Exception in add_rollouts_for_wandb text extraction: {e}")
original_text = "extraction_failed"
# Keep a reasonable number of rollouts
@@ -1144,7 +1143,9 @@ class TextReversalEnv(BaseEnv):
# Fallback to decoding tokens
model_response = full_text
extracted_reversed = self._extract_reversed_text(model_response) or "format_error"
extracted_reversed = (
self._extract_reversed_text(model_response) or "format_error"
)
except Exception as e:
print(
f"DEBUG: Exception in add_rollouts_for_wandb reversal parsing: {e}"
@@ -1206,10 +1207,18 @@ class TextReversalEnv(BaseEnv):
# Reversal-specific metrics
if self.total_attempts > 0:
wandb_metrics["train/success_rate"] = self.successful_reversals / self.total_attempts
wandb_metrics["train/failure_rate"] = self.failed_reversals / self.total_attempts
wandb_metrics["train/format_error_rate"] = self.format_errors / self.total_attempts
wandb_metrics["train/format_compliance_rate"] = 1.0 - (self.format_errors / self.total_attempts)
wandb_metrics["train/success_rate"] = (
self.successful_reversals / self.total_attempts
)
wandb_metrics["train/failure_rate"] = (
self.failed_reversals / self.total_attempts
)
wandb_metrics["train/format_error_rate"] = (
self.format_errors / self.total_attempts
)
wandb_metrics["train/format_compliance_rate"] = 1.0 - (
self.format_errors / self.total_attempts
)
# Configuration metrics
wandb_metrics.update(