mirror of https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
Other major changes
This commit is contained in:
parent b56d03b25c
commit ec09a1caee
1 changed file with 7 additions and 43 deletions
@@ -148,7 +148,7 @@ class RefusalBenchConfig(BaseEnvConfig):
     )

     min_response_length: int = Field(
-        default=10,
+        default=1,
         ge=1,
         description="Minimum response length to consider valid.",
     )
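
BaseEnvConfig fields are pydantic Fields, so the ge=1 bound still rejects non-positive lengths after the default drops from 10 to 1. A minimal standalone sketch of that behavior (DemoConfig is hypothetical; only the field itself comes from this hunk):

    from pydantic import BaseModel, Field, ValidationError

    class DemoConfig(BaseModel):
        # Same shape as the hunk above: default lowered from 10 to 1,
        # ge=1 still rejects anything below 1.
        min_response_length: int = Field(default=1, ge=1)

    print(DemoConfig().min_response_length)  # -> 1
    try:
        DemoConfig(min_response_length=0)  # violates ge=1
    except ValidationError:
        print("rejected: below the ge=1 bound")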
@@ -276,7 +276,6 @@ class RefusalBenchEnv(BaseEnv):
         try:
             # Check if it's a local file
             if os.path.exists(dataset_path):
                 # Local file - use appropriate loader based on extension
                 if dataset_path.endswith(".jsonl") or dataset_path.endswith(".json"):
                     dataset = load_dataset(
                         "json",
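
The truncated load_dataset("json", ...) call above pairs the "json" builder with a local file path. A sketch of the usual pattern, assuming the elided keyword arguments are data_files and split (the path is hypothetical):

    from datasets import load_dataset

    dataset_path = "data/refusal_train.jsonl"  # hypothetical local file
    if dataset_path.endswith(".jsonl") or dataset_path.endswith(".json"):
        # The "json" builder handles both .json and .jsonl files.
        dataset = load_dataset("json", data_files=dataset_path, split="train")
        print(dataset[0])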
@@ -379,9 +378,9 @@ class RefusalBenchEnv(BaseEnv):
         )
         server_configs = [
             APIServerConfig(
-                model_name="Hermes-3-Llama-3.1-405B",
-                base_url="https://inference-api.nousresearch.com/v1",
-                api_key=os.getenv("NOUS_API_KEY"),
+                model_name="gpt-4o",
+                base_url="https://api.openai.com/v1",
+                api_key=os.getenv("OPENAI_API_KEY"),
                 num_max_requests_at_once=32,
                 num_requests_for_eval=256,
             ),
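
The hunk hardcodes the swap from the Nous endpoint to OpenAI. An environment-driven variant (illustrative only, not what this commit does; the atroposlib import path is assumed) could keep both providers available:

    import os

    from atroposlib.envs.base import APIServerConfig  # import path assumed

    use_openai = bool(os.getenv("OPENAI_API_KEY"))
    server_configs = [
        APIServerConfig(
            model_name="gpt-4o" if use_openai else "Hermes-3-Llama-3.1-405B",
            base_url=(
                "https://api.openai.com/v1"
                if use_openai
                else "https://inference-api.nousresearch.com/v1"
            ),
            api_key=os.getenv("OPENAI_API_KEY" if use_openai else "NOUS_API_KEY"),
            num_max_requests_at_once=32,
            num_requests_for_eval=256,
        ),
    ]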
@@ -423,20 +422,6 @@ class RefusalBenchEnv(BaseEnv):
         # Analyze datasets
         self._analyze_datasets()

-        # Show configuration info
-        print("\nRefusalBench Configuration:")
-        print(
-            f" - Training dataset: {self.config.train_dataset} (split: {self.config.train_split})"
-        )
-        print(
-            f" - Evaluation dataset: {self.config.eval_dataset} (split: {self.config.eval_split})"
-        )
-        print(f" - Thinking mode: {self.config.thinking_mode}")
-        print(f" - Judge model: {self.config.judge_model_name}")
-        print(f" - Judge rate limiting: max {self.config.judge_max_concurrent_calls} concurrent, {self.config.judge_rate_limit_delay}s delay")
-        print(f" - Fallback scoring: {self.config.use_fallback_scoring}")
-        print(f" - Inverted categories: {self.config.inverted_categories}")
-
         self.iter = 0

     def _analyze_datasets(self):
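
The commit deletes the startup banner outright. If that information is still wanted without unconditional prints, a logging-based variant (illustrative, not part of this commit; `config` stands in for self.config) keeps it suppressible via log levels:

    import logging

    logger = logging.getLogger("refusalbench")

    def log_configuration(config) -> None:
        # Same fields the deleted banner printed, now behind INFO level.
        logger.info("RefusalBench configuration:")
        logger.info(" - Training dataset: %s (split: %s)", config.train_dataset, config.train_split)
        logger.info(" - Judge model: %s", config.judge_model_name)
        logger.info(" - Fallback scoring: %s", config.use_fallback_scoring)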
@@ -448,7 +433,7 @@ class RefusalBenchEnv(BaseEnv):

         for item in self.train:
             total_train_items += 1
-            category = item.get("overarching_category", "Unknown")
+            category = item.get("overarching_category", "Unknown")  # The category needs to be called overarching_category; otherwise it won't work.
             train_category_counts[category] = (
                 train_category_counts.get(category, 0) + 1
             )
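
The counting idiom in this hunk is equivalent to collections.Counter, which also makes the new comment's point concrete: rows without the exact "overarching_category" key all collapse into "Unknown". A standalone sketch with hypothetical rows:

    from collections import Counter

    train = [
        {"prompt": "...", "overarching_category": "Cybercrime"},
        {"prompt": "...", "overarching_category": "Cybercrime"},
        {"prompt": "..."},  # unlabeled row -> "Unknown"
    ]
    counts = Counter(item.get("overarching_category", "Unknown") for item in train)
    print(counts)  # Counter({'Cybercrime': 2, 'Unknown': 1})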
@@ -1030,6 +1015,7 @@ Your answer:"""
     async def rollout_and_score_eval(self, eval_item: dict) -> dict:
         """Rollout and score evaluation."""
         try:
+            # Again, the correct labels are needed here
             prompt_text = eval_item.get("prompt", "")
             overarching_category = eval_item.get("overarching_category", "Unknown")

@@ -1053,33 +1039,11 @@ Your answer:"""

         for attempt in range(max_retries):
             try:
-                print(f" Prompt: {prompt_text[:100]}...")
-                print(f" Category: {overarching_category}")
-                print(f" Danger: {danger}")
-                print(f" Completion params: {completion_params}")
-
-                # Log server configuration
-                if hasattr(self.server, 'servers') and self.server.servers:
-                    server = self.server.servers[0]
-                    print(f" Server model: {server.config.model_name}")
-                    print(f" Server URL: {server.config.base_url}")
-                    print(f" Server healthy: {server.server_healthy}")
-                    print(f" Server sem value: {server.sem._value}")
-                    print(f" Server eval_sem value: {server.eval_sem._value}")

                 completion = await self.server.chat_completion(
                     messages=messages, **completion_params
                 )

-                print(f"✅ Got completion response!")
-                print(f" Completion type: {type(completion)}")
-                print(f" Has choices: {hasattr(completion, 'choices')}")
-                if hasattr(completion, 'choices'):
-                    print(f" Number of choices: {len(completion.choices)}")
-                    if completion.choices:
-                        print(f" First choice has message: {hasattr(completion.choices[0], 'message')}")
-                        if hasattr(completion.choices[0], 'message'):
-                            print(f" First choice content length: {len(completion.choices[0].message.content) if completion.choices[0].message.content else 0}")

                 # Log full debug response
                 self._log_full_debug_response(
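
With the debug prints gone, the surviving structure is a plain retry loop around an async chat_completion call. A self-contained sketch of that shape (the backoff policy and exception handling are assumptions, not code from this commit):

    import asyncio

    async def call_with_retries(server, messages, max_retries=3, **completion_params):
        for attempt in range(max_retries):
            try:
                return await server.chat_completion(messages=messages, **completion_params)
            except Exception:
                if attempt == max_retries - 1:
                    raise
                await asyncio.sleep(2 ** attempt)  # 1s, 2s, 4s, ...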
@@ -1332,7 +1296,7 @@ Your answer:"""
         }

     async def evaluate(self, *args, **kwargs) -> None:
-        """Evaluate the model on the test dataset with optional averaging over multiple runs."""
+        """Evaluate the model on the test dataset."""
         if self.config.eval_runs_for_average == 1:
             # Single run - use the original behavior
             result = await self._evaluate_single_run(1, *args, **kwargs)
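
The new docstring drops the mention of averaging, but the code still branches on eval_runs_for_average. An illustrative shape for the multi-run branch (hypothetical helper; assumes _evaluate_single_run returns a flat dict of metric name -> float):

    async def _evaluate_averaged(self, *args, **kwargs) -> dict:
        runs = [
            await self._evaluate_single_run(run_idx + 1, *args, **kwargs)
            for run_idx in range(self.config.eval_runs_for_average)
        ]
        # Average each metric across runs.
        return {key: sum(run[key] for run in runs) / len(runs) for key in runs[0]}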