mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
a8c3e67062
commit
8b0a70131b
1 changed files with 60 additions and 51 deletions
|
|
@ -33,7 +33,7 @@ class TextReversalConfig(BaseEnvConfig):
|
|||
default=None,
|
||||
description="Custom thinking prompt. If None, uses the default thinking prompt.",
|
||||
)
|
||||
|
||||
|
||||
custom_thinking_prompt: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Custom thinking prompt. If None, uses the default thinking prompt.",
|
||||
|
|
@ -258,7 +258,7 @@ class TextReversalEnv(BaseEnv):
|
|||
APIServerConfig(
|
||||
model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
|
||||
base_url="http://localhost:9004/v1",
|
||||
api_key="x"
|
||||
api_key="x",
|
||||
),
|
||||
]
|
||||
return env_config, server_configs
|
||||
|
|
@ -304,35 +304,39 @@ class TextReversalEnv(BaseEnv):
|
|||
if hasattr(self.train, "__iter__"):
|
||||
total_train_items = len(self.train)
|
||||
print(f"\nTraining dataset analysis ({total_train_items} total items):")
|
||||
|
||||
|
||||
# Show some sample text lengths
|
||||
text_lengths = []
|
||||
for item in list(self.train)[:100]: # Sample first 100 items
|
||||
text = item.get("text", "")
|
||||
text_lengths.append(len(text))
|
||||
|
||||
|
||||
if text_lengths:
|
||||
avg_length = sum(text_lengths) / len(text_lengths)
|
||||
min_length = min(text_lengths)
|
||||
max_length = max(text_lengths)
|
||||
print(f" - Sample text lengths: avg={avg_length:.1f}, min={min_length}, max={max_length}")
|
||||
print(
|
||||
f" - Sample text lengths: avg={avg_length:.1f}, min={min_length}, max={max_length}"
|
||||
)
|
||||
|
||||
# Analyze evaluation dataset composition
|
||||
if hasattr(self.test, "__iter__"):
|
||||
total_eval_items = len(self.test)
|
||||
print(f"\nEvaluation dataset analysis ({total_eval_items} total items):")
|
||||
|
||||
|
||||
# Show some sample text lengths
|
||||
text_lengths = []
|
||||
for item in list(self.test)[:100]: # Sample first 100 items
|
||||
text = item.get("text", "")
|
||||
text_lengths.append(len(text))
|
||||
|
||||
|
||||
if text_lengths:
|
||||
avg_length = sum(text_lengths) / len(text_lengths)
|
||||
min_length = min(text_lengths)
|
||||
max_length = max(text_lengths)
|
||||
print(f" - Sample text lengths: avg={avg_length:.1f}, min={min_length}, max={max_length}")
|
||||
print(
|
||||
f" - Sample text lengths: avg={avg_length:.1f}, min={min_length}, max={max_length}"
|
||||
)
|
||||
|
||||
# Show configuration info
|
||||
print("\nText Reversal Configuration:")
|
||||
|
|
@ -361,9 +365,7 @@ class TextReversalEnv(BaseEnv):
|
|||
print(
|
||||
"\n🔍 FULL DEBUG MODE ENABLED - Will log all API requests and responses"
|
||||
)
|
||||
print(
|
||||
" 📊 Will show: first/last 100 chars of prompts and responses"
|
||||
)
|
||||
print(" 📊 Will show: first/last 100 chars of prompts and responses")
|
||||
print(
|
||||
f" ⚙️ Retry settings: max_retries={self.config.max_retries}, retry_delay={self.config.retry_delay}s"
|
||||
)
|
||||
|
|
@ -394,28 +396,20 @@ class TextReversalEnv(BaseEnv):
|
|||
# Local file - use appropriate loader based on extension
|
||||
if dataset_path.endswith(".jsonl") or dataset_path.endswith(".json"):
|
||||
dataset = load_dataset(
|
||||
"json",
|
||||
data_files=dataset_path,
|
||||
split=split or "train"
|
||||
"json", data_files=dataset_path, split=split or "train"
|
||||
)
|
||||
elif dataset_path.endswith(".csv"):
|
||||
dataset = load_dataset(
|
||||
"csv",
|
||||
data_files=dataset_path,
|
||||
split=split or "train"
|
||||
"csv", data_files=dataset_path, split=split or "train"
|
||||
)
|
||||
elif dataset_path.endswith(".parquet"):
|
||||
dataset = load_dataset(
|
||||
"parquet",
|
||||
data_files=dataset_path,
|
||||
split=split or "train"
|
||||
"parquet", data_files=dataset_path, split=split or "train"
|
||||
)
|
||||
else:
|
||||
# Try JSON as default
|
||||
dataset = load_dataset(
|
||||
"json",
|
||||
data_files=dataset_path,
|
||||
split=split or "train"
|
||||
"json", data_files=dataset_path, split=split or "train"
|
||||
)
|
||||
|
||||
print(
|
||||
|
|
@ -425,9 +419,7 @@ class TextReversalEnv(BaseEnv):
|
|||
else:
|
||||
# HuggingFace dataset
|
||||
if split:
|
||||
dataset = load_dataset(
|
||||
dataset_path, split=split
|
||||
)
|
||||
dataset = load_dataset(dataset_path, split=split)
|
||||
else:
|
||||
dataset_dict = load_dataset(dataset_path)
|
||||
# If no split specified, try to get the first available split
|
||||
|
|
@ -463,10 +455,10 @@ class TextReversalEnv(BaseEnv):
|
|||
def _extract_reversed_text(self, response: str) -> Optional[str]:
|
||||
"""
|
||||
Extract text from within <reversed_text> tags.
|
||||
|
||||
|
||||
Args:
|
||||
response: Model response text
|
||||
|
||||
|
||||
Returns:
|
||||
Extracted text or None if not found or multiple blocks found
|
||||
"""
|
||||
|
|
@ -487,11 +479,11 @@ class TextReversalEnv(BaseEnv):
|
|||
|
||||
# Find all content between <reversed_text> tags
|
||||
matches = self._reversed_text_pattern.findall(response)
|
||||
|
||||
|
||||
# Must have exactly one reversed_text block
|
||||
if len(matches) != 1:
|
||||
return None
|
||||
|
||||
|
||||
return matches[0].strip()
|
||||
|
||||
def _create_reversal_prompt(self, text: str) -> str:
|
||||
|
|
@ -539,7 +531,7 @@ class TextReversalEnv(BaseEnv):
|
|||
"""
|
||||
try:
|
||||
original_text = item.get("text", "")
|
||||
|
||||
|
||||
# Validate required fields
|
||||
if not original_text:
|
||||
return None, None
|
||||
|
|
@ -723,7 +715,7 @@ class TextReversalEnv(BaseEnv):
|
|||
|
||||
# Extract reversed text from model response
|
||||
extracted_reversed = self._extract_reversed_text(model_response)
|
||||
|
||||
|
||||
# Score 1.0 if exact match, 0.0 otherwise
|
||||
reward = 1.0 if extracted_reversed == expected_reversed else 0.0
|
||||
|
||||
|
|
@ -760,12 +752,18 @@ class TextReversalEnv(BaseEnv):
|
|||
group_size = len(scores["scores"])
|
||||
any_success = group_successes > 0
|
||||
success_indicator = "✅" if any_success else "❌"
|
||||
|
||||
|
||||
# Calculate running totals
|
||||
total_success_rate = (self.successful_reversals / self.total_attempts * 100) if self.total_attempts > 0 else 0.0
|
||||
|
||||
print(f"{success_indicator} Group scored: {group_successes}/{group_size} successful reversals | "
|
||||
f"Total success rate: {self.successful_reversals}/{self.total_attempts} ({total_success_rate:.1f}%)")
|
||||
total_success_rate = (
|
||||
(self.successful_reversals / self.total_attempts * 100)
|
||||
if self.total_attempts > 0
|
||||
else 0.0
|
||||
)
|
||||
|
||||
print(
|
||||
f"{success_indicator} Group scored: {group_successes}/{group_size} successful reversals | "
|
||||
f"Total success rate: {self.successful_reversals}/{self.total_attempts} ({total_success_rate:.1f}%)"
|
||||
)
|
||||
|
||||
# Update percent correct buffer
|
||||
for score in scores["scores"]:
|
||||
|
|
@ -898,7 +896,7 @@ class TextReversalEnv(BaseEnv):
|
|||
|
||||
# Extract reversed text from model response
|
||||
extracted_reversed = self._extract_reversed_text(model_response)
|
||||
|
||||
|
||||
# Score 1.0 if exact match, 0.0 otherwise
|
||||
score = 1.0 if extracted_reversed == expected_reversed else 0.0
|
||||
|
||||
|
|
@ -956,8 +954,7 @@ class TextReversalEnv(BaseEnv):
|
|||
|
||||
try:
|
||||
eval_tasks = [
|
||||
self.rollout_and_score_eval(test_item)
|
||||
for test_item in self.test
|
||||
self.rollout_and_score_eval(test_item) for test_item in self.test
|
||||
]
|
||||
results = await tqdm_asyncio.gather(*eval_tasks)
|
||||
|
||||
|
|
@ -994,7 +991,7 @@ class TextReversalEnv(BaseEnv):
|
|||
format_compliant = sum(
|
||||
1 for sample in samples if sample.get("format_compliant", False)
|
||||
)
|
||||
|
||||
|
||||
thinking_mode_used = self.config.thinking_mode
|
||||
|
||||
# Get response metrics
|
||||
|
|
@ -1104,16 +1101,18 @@ class TextReversalEnv(BaseEnv):
|
|||
if role_dict_converted.get("role") == "user":
|
||||
user_content = role_dict_converted.get("content", "")
|
||||
# Extract original text from the user message
|
||||
lines = user_content.split('\n')
|
||||
lines = user_content.split("\n")
|
||||
for line in lines:
|
||||
if line.strip() and not line.startswith("Please reverse") and not line.startswith("The text to reverse:"):
|
||||
if (
|
||||
line.strip()
|
||||
and not line.startswith("Please reverse")
|
||||
and not line.startswith("The text to reverse:")
|
||||
):
|
||||
original_text = line.strip()
|
||||
break
|
||||
break
|
||||
except Exception as e:
|
||||
print(
|
||||
f"DEBUG: Exception in add_rollouts_for_wandb text extraction: {e}"
|
||||
)
|
||||
print(f"DEBUG: Exception in add_rollouts_for_wandb text extraction: {e}")
|
||||
original_text = "extraction_failed"
|
||||
|
||||
# Keep a reasonable number of rollouts
|
||||
|
|
@ -1144,7 +1143,9 @@ class TextReversalEnv(BaseEnv):
|
|||
# Fallback to decoding tokens
|
||||
model_response = full_text
|
||||
|
||||
extracted_reversed = self._extract_reversed_text(model_response) or "format_error"
|
||||
extracted_reversed = (
|
||||
self._extract_reversed_text(model_response) or "format_error"
|
||||
)
|
||||
except Exception as e:
|
||||
print(
|
||||
f"DEBUG: Exception in add_rollouts_for_wandb reversal parsing: {e}"
|
||||
|
|
@ -1206,10 +1207,18 @@ class TextReversalEnv(BaseEnv):
|
|||
|
||||
# Reversal-specific metrics
|
||||
if self.total_attempts > 0:
|
||||
wandb_metrics["train/success_rate"] = self.successful_reversals / self.total_attempts
|
||||
wandb_metrics["train/failure_rate"] = self.failed_reversals / self.total_attempts
|
||||
wandb_metrics["train/format_error_rate"] = self.format_errors / self.total_attempts
|
||||
wandb_metrics["train/format_compliance_rate"] = 1.0 - (self.format_errors / self.total_attempts)
|
||||
wandb_metrics["train/success_rate"] = (
|
||||
self.successful_reversals / self.total_attempts
|
||||
)
|
||||
wandb_metrics["train/failure_rate"] = (
|
||||
self.failed_reversals / self.total_attempts
|
||||
)
|
||||
wandb_metrics["train/format_error_rate"] = (
|
||||
self.format_errors / self.total_attempts
|
||||
)
|
||||
wandb_metrics["train/format_compliance_rate"] = 1.0 - (
|
||||
self.format_errors / self.total_attempts
|
||||
)
|
||||
|
||||
# Configuration metrics
|
||||
wandb_metrics.update(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue