Add support for reasoning models and their variety of providers/endpoints

teknium 2025-12-30 00:23:00 +00:00
parent 1c306d3b17
commit 62fa51240c
6 changed files with 1551 additions and 16 deletions

@@ -41,6 +41,7 @@ from atroposlib.utils.io import parse_http_response
 from atroposlib.utils.metrics import get_std_min_max_avg
 from ..type_definitions import Item, Message
+from .server_handling.server_baseline import ReasoningConfig
 from .server_handling.server_manager import (
     APIServer,
     APIServerConfig,
@@ -178,6 +179,31 @@ class BaseEnvConfig(BaseModel):
         default=600,
         description="Timeout for a task, in seconds; if -1, no timeout",
     )
+    thinking_mode: bool = Field(
+        default=False,
+        description="Whether to enable reasoning/thinking mode in API requests. "
+        "When True, requests include extra_body parameters to trigger model reasoning. "
+        "Automatically set to True if reasoning_effort or max_reasoning_tokens are specified.",
+    )
+    reasoning_effort: Optional[str] = Field(
+        default=None,
+        description="Reasoning effort level. Valid values: 'none', 'minimal', 'low', "
+        "'medium', 'high', 'xhigh'. For OpenAI models, values are mapped to their "
+        "supported levels ('low', 'medium', 'high'). Default None (not specified).",
+    )
+    max_reasoning_tokens: Optional[int] = Field(
+        default=None,
+        ge=1024,
+        le=32000,
+        description="Maximum tokens for reasoning (1024-32000). Only supported by "
+        "some providers (not OpenAI official). Default None (not specified).",
+    )
+    custom_thinking_prompt: Optional[str] = Field(
+        default=None,
+        description="Custom system prompt to prepend for thinking mode. If None, "
+        "no thinking prompt is injected. Use HERMES_REASONING_PROMPT from "
+        "eval_helpers for the standard Hermes reasoning prompt.",
+    )
 
 
 class BaseEnv(ABC):
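
Taken together, the four fields above make reasoning behavior configurable per environment. A minimal usage sketch (hypothetical values; BaseEnvConfig and the field names come from the hunk above, the import path is assumed, and other config fields are omitted for brevity):

from atroposlib.envs.base import BaseEnvConfig  # module path assumed from the relative imports above

config = BaseEnvConfig(
    thinking_mode=True,           # enable reasoning/thinking mode in API requests
    reasoning_effort="high",      # 'none', 'minimal', 'low', 'medium', 'high', or 'xhigh'
    max_reasoning_tokens=8192,    # pydantic enforces the 1024-32000 range via ge/le
    custom_thinking_prompt=None,  # or HERMES_REASONING_PROMPT from eval_helpers
)

Per the thinking_mode description, specifying reasoning_effort or max_reasoning_tokens is intended to set thinking_mode to True automatically, so the explicit flag here is redundant but harmless.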
@@ -203,8 +229,20 @@ class BaseEnv(ABC):
         self.last_loop_time = None
         self.last_completed_item = None
         self.config = config
+        # Build reasoning config from env config fields
+        reasoning_config = ReasoningConfig(
+            enabled=config.thinking_mode,
+            effort=config.reasoning_effort,
+            max_tokens=config.max_reasoning_tokens,
+        )
         self.server = ServerManager(
-            server_configs, slurm=slurm, testing=testing, server_class=self.server_cls
+            server_configs,
+            slurm=slurm,
+            testing=testing,
+            server_class=self.server_cls,
+            reasoning_config=reasoning_config,
         )
         self.workers = set()
         self.eval_workers = set()
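
For reference, here is the same wiring in isolation, a sketch that assumes only the ReasoningConfig and ServerManager signatures visible in this diff (the import paths, server_configs, and all values are placeholders):

from atroposlib.envs.server_handling.server_baseline import ReasoningConfig  # path assumed
from atroposlib.envs.server_handling.server_manager import APIServer, ServerManager  # path assumed

reasoning_config = ReasoningConfig(
    enabled=True,      # mirrors config.thinking_mode
    effort="medium",   # mirrors config.reasoning_effort
    max_tokens=4096,   # mirrors config.max_reasoning_tokens
)
server = ServerManager(
    server_configs,    # e.g. a list of APIServerConfig entries
    slurm=False,
    testing=True,
    server_class=APIServer,
    reasoning_config=reasoning_config,
)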