mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-25 17:10:42 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
ef9c0c3699
commit
afab28dfa9
37 changed files with 4868 additions and 4052 deletions
|
|
@ -6,6 +6,11 @@ from typing import Dict, List, Optional, Tuple
|
|||
|
||||
import openai
|
||||
from datasets import load_dataset
|
||||
from eval_helpers import (
|
||||
create_system_content,
|
||||
get_default_thinking_prompt,
|
||||
save_eval_results,
|
||||
)
|
||||
from pydantic import Field
|
||||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
|
|
@ -19,11 +24,6 @@ from atroposlib.envs.base import (
|
|||
ScoredDataGroup,
|
||||
)
|
||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||
from eval_helpers import (
|
||||
get_default_thinking_prompt,
|
||||
create_system_content,
|
||||
save_eval_results,
|
||||
)
|
||||
|
||||
|
||||
class ArenaHardConfig(BaseEnvConfig):
|
||||
|
|
@ -244,13 +244,13 @@ class ArenaHardEnv(BaseEnv):
|
|||
def _get_system_prompt(self) -> Optional[str]:
|
||||
"""Get system prompt for non-thinking mode."""
|
||||
return self.config.custom_system_prompt
|
||||
|
||||
|
||||
def _create_system_content(self) -> Optional[str]:
|
||||
"""Create system message content based on thinking mode."""
|
||||
return create_system_content(
|
||||
self.config.thinking_mode,
|
||||
self.config.custom_thinking_prompt,
|
||||
self.config.custom_system_prompt
|
||||
self.config.custom_system_prompt,
|
||||
)
|
||||
|
||||
def _load_dataset(self, dataset_path: str, split: str = None) -> List[Dict]:
|
||||
|
|
@ -562,7 +562,7 @@ class ArenaHardEnv(BaseEnv):
|
|||
debug_params = {"temperature": self.config.judge_temperature}
|
||||
if self.config.judge_max_tokens > 0:
|
||||
debug_params["max_tokens"] = self.config.judge_max_tokens
|
||||
|
||||
|
||||
self._log_full_debug_request(
|
||||
messages,
|
||||
debug_params,
|
||||
|
|
@ -576,7 +576,7 @@ class ArenaHardEnv(BaseEnv):
|
|||
}
|
||||
if self.config.judge_max_tokens > 0:
|
||||
kwargs["max_tokens"] = self.config.judge_max_tokens
|
||||
|
||||
|
||||
completion = await self.judge_client.chat.completions.create(**kwargs)
|
||||
|
||||
self._log_full_debug_response(completion, "JUDGE API CALL")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue