From 6bc962c746e4856800d925c14b923401cbb30d6e Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Mon, 16 Feb 2026 11:46:20 -0500 Subject: [PATCH 01/23] initial commit --- atroposlib/api/server.py | 4 + atroposlib/envs/base.py | 230 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 234 insertions(+) diff --git a/atroposlib/api/server.py b/atroposlib/api/server.py index 4a94c6d8..f6b1d4d1 100644 --- a/atroposlib/api/server.py +++ b/atroposlib/api/server.py @@ -145,6 +145,9 @@ class ScoredData(BaseModel): group_overrides: Optional[dict] = None images: Optional[Any] = None env_id: Optional[int] = None # ID of the environment that generated this data + # On-policy distillation: top-K logprobs from teacher model + # Structure: [sequence][position][top_k] = [token_id, logprob] + onpolicydistill_logprobs: Optional[List[List[List[List]]]] = None @field_validator("messages", mode="before") @classmethod @@ -182,6 +185,7 @@ def _scored_data_to_dict(scored_data: ScoredData) -> Dict[str, Any]: "group_overrides": scored_data.group_overrides, "images": scored_data.images, "env_id": scored_data.env_id, + "onpolicydistill_logprobs": scored_data.onpolicydistill_logprobs, } diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index e9b672a6..81569f0d 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -66,6 +66,9 @@ class ScoredDataGroup(TypedDict): group_overrides: Optional[Dict] overrides: Optional[List[Dict]] images: Optional[Any] + # On-policy distillation: top-K logprobs from teacher model + # Structure: List[List[List[List[Union[int, float]]]]] = [sequence][position][top_k] = [token_id, logprob] + onpolicydistill_logprobs: Optional[List[List[List[List]]]] class ScoredDataItem(TypedDict): @@ -78,6 +81,8 @@ class ScoredDataItem(TypedDict): group_overrides: Optional[Dict] overrides: Optional[Dict] images: Optional[Any] + # On-policy distillation: top-K logprobs from teacher model per position + onpolicydistill_logprobs: 
Optional[List[List[List]]] class EvalHandlingEnum(Enum): @@ -205,6 +210,31 @@ class BaseEnvConfig(BaseModel): "eval_helpers for the standard Hermes reasoning prompt.", ) + # On-policy distillation settings + teacher_base_url: Optional[str] = Field( + default=None, + description="Base URL of teacher model for distillation. Supports any OpenAI-compatible API " + "(vLLM, OpenAI, Together, etc.). Examples: 'http://localhost:8001/v1', 'https://api.openai.com/v1'", + ) + teacher_model_name: Optional[str] = Field( + default=None, + description="Model name for teacher API calls (e.g., 'gpt-4o', 'meta-llama/Llama-3-70b'). " + "If None, uses 'default' which works for single-model vLLM servers.", + ) + teacher_api_key: Optional[str] = Field( + default=None, + description="API key for teacher model. Can also be set via TEACHER_API_KEY env var.", + ) + teacher_top_k: int = Field( + default=10, + description="Number of top logprobs to fetch from teacher model per position.", + ) + distillation_enabled: bool = Field( + default=False, + description="Enable on-policy distillation. When True, automatically fetches teacher logprobs " + "after scoring and includes them in data sent to trainer.", + ) + class BaseEnv(ABC): name: Optional[str] = None @@ -309,6 +339,190 @@ class BaseEnv(ABC): # Calculate derived batch size return int(self.config.batch_size * effective_fraction) + async def get_teacher_logprobs( + self, + token_sequences: List[List[int]], + messages_list: Optional[List[List[Dict]]] = None, + top_k: Optional[int] = None, + ) -> List[List[List[List]]]: + """ + Fetch top-K logprobs from teacher model for given sequences. + + Supports any OpenAI-compatible API (vLLM, OpenAI, Together, etc.). + + Args: + token_sequences: List of token ID sequences to get logprobs for + messages_list: Optional list of message histories (for chat APIs). + If provided, uses chat/completions with logprobs. 
+ top_k: Number of top logprobs to fetch (defaults to config.teacher_top_k) + + Returns: + List of top-K logprobs per position per sequence + Structure: [batch][position][top_k] = [token_id, logprob] + Returns empty list if teacher_base_url is not configured. + """ + if not self.config.teacher_base_url: + return [] + + if top_k is None: + top_k = self.config.teacher_top_k + + # Get API key from config or environment + api_key = self.config.teacher_api_key or os.environ.get("TEACHER_API_KEY", "") + model_name = self.config.teacher_model_name or "default" + + headers = {"Content-Type": "application/json"} + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + + results = [] + + try: + async with aiohttp.ClientSession() as session: + for i, tokens in enumerate(token_sequences): + # Decode tokens to text + full_text = self.tokenizer.decode(tokens, skip_special_tokens=False) + + # Try vLLM-style completions first (supports prompt_logprobs) + # This is most efficient as it doesn't generate new tokens + request_data = { + "model": model_name, + "prompt": full_text, + "max_tokens": 1, + "temperature": 1.0, + "logprobs": top_k, + "echo": True, # Include prompt in response with logprobs + } + + try: + async with session.post( + f"{self.config.teacher_base_url}/completions", + json=request_data, + headers=headers, + timeout=aiohttp.ClientTimeout(total=120), + ) as response: + if response.status == 200: + data = await response.json() + seq_result = self._parse_completion_logprobs(data, top_k) + if seq_result: + results.append(seq_result) + continue + except Exception: + pass # Fall through to chat completions + + # Fallback: Use chat/completions with logprobs (OpenAI style) + # This requires messages format + if messages_list and i < len(messages_list): + messages = messages_list[i] + else: + # Convert text to simple message format + messages = [{"role": "user", "content": full_text}] + + chat_request = { + "model": model_name, + "messages": messages, + "max_tokens": 
1, + "temperature": 1.0, + "logprobs": True, + "top_logprobs": top_k, + } + + try: + async with session.post( + f"{self.config.teacher_base_url}/chat/completions", + json=chat_request, + headers=headers, + timeout=aiohttp.ClientTimeout(total=120), + ) as response: + if response.status == 200: + data = await response.json() + seq_result = self._parse_chat_logprobs(data, top_k) + results.append(seq_result) + else: + logger.warning(f"Teacher API returned {response.status}") + results.append([]) + except Exception as e: + logger.warning(f"Teacher chat request failed: {e}") + results.append([]) + + return results + + except Exception as e: + logger.error(f"Error fetching teacher logprobs: {e}") + return [] + + def _parse_completion_logprobs( + self, data: Dict, top_k: int + ) -> List[List[List]]: + """Parse logprobs from vLLM-style completion response.""" + try: + choice = data.get("choices", [{}])[0] + logprobs_data = choice.get("logprobs", {}) + + # vLLM returns top_logprobs as list of dicts + top_logprobs = logprobs_data.get("top_logprobs", []) + tokens = logprobs_data.get("tokens", []) + + if not top_logprobs: + return [] + + seq_result = [] + for pos_logprobs in top_logprobs: + if pos_logprobs is None: + seq_result.append([]) + elif isinstance(pos_logprobs, dict): + # Format: {token_str: logprob, ...} + sorted_items = sorted( + pos_logprobs.items(), + key=lambda x: x[1], + reverse=True + )[:top_k] + pos_result = [] + for token_str, logprob in sorted_items: + # Convert token string to ID + token_ids = self.tokenizer.encode(token_str, add_special_tokens=False) + if token_ids: + pos_result.append([token_ids[0], float(logprob)]) + seq_result.append(pos_result) + else: + seq_result.append([]) + + return seq_result + except Exception as e: + logger.warning(f"Error parsing completion logprobs: {e}") + return [] + + def _parse_chat_logprobs( + self, data: Dict, top_k: int + ) -> List[List[List]]: + """Parse logprobs from OpenAI-style chat completion response.""" + try: + 
choice = data.get("choices", [{}])[0] + logprobs_data = choice.get("logprobs", {}) + + if not logprobs_data: + return [] + + content = logprobs_data.get("content", []) + seq_result = [] + + for token_data in content: + top_logprobs = token_data.get("top_logprobs", []) + pos_result = [] + for item in top_logprobs[:top_k]: + token_str = item.get("token", "") + logprob = item.get("logprob", 0.0) + # Convert token string to ID + token_ids = self.tokenizer.encode(token_str, add_special_tokens=False) + if token_ids: + pos_result.append([token_ids[0], float(logprob)]) + seq_result.append(pos_result) + + return seq_result + except Exception as e: + logger.warning(f"Error parsing chat logprobs: {e}") + return [] + @classmethod def config_init( cls, @@ -909,6 +1123,22 @@ class BaseEnv(ABC): for i in range(len(group["tokens"])) ] + # Automatic on-policy distillation: fetch teacher logprobs if enabled + if self.config.distillation_enabled and self.config.teacher_base_url: + if group.get("onpolicydistill_logprobs") is None: + try: + teacher_logprobs = await self.get_teacher_logprobs( + token_sequences=group["tokens"], + messages_list=group.get("messages"), + ) + if teacher_logprobs: + group["onpolicydistill_logprobs"] = teacher_logprobs + logger.debug( + f"Added teacher logprobs for {len(teacher_logprobs)} sequences" + ) + except Exception as e: + logger.warning(f"Failed to fetch teacher logprobs: {e}") + await self.add_rollouts_for_wandb(group, item) if self.jsonl_writer is not None: From 3fdaff9bb4303a8848b37602cf4c788b7e1460d6 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Mon, 16 Feb 2026 17:18:01 -0500 Subject: [PATCH 02/23] Fix math_server_zero.py to support CLI OpenAI arguments Change ServerBaseline to APIServerConfig in config_init() so that --openai.base_url and other CLI arguments work for on-policy distillation. 
Co-authored-by: Cursor --- environments/math_server_zero.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/environments/math_server_zero.py b/environments/math_server_zero.py index 1432ab4d..0d6c140c 100644 --- a/environments/math_server_zero.py +++ b/environments/math_server_zero.py @@ -18,6 +18,7 @@ from pydantic import Field from tqdm.asyncio import tqdm_asyncio from atroposlib.envs.base import ( + APIServerConfig, BaseEnv, BaseEnvConfig, EvalHandlingEnum, @@ -119,7 +120,7 @@ class MathEnv(BaseEnv): def __init__( self, config: RSConfig, - server_configs: ServerBaseline, + server_configs: APIServerConfig | ServerBaseline, slurm=True, testing=False, ): @@ -137,7 +138,7 @@ class MathEnv(BaseEnv): self.iter = 0 @classmethod - def config_init(cls) -> Tuple[RSConfig, ServerBaseline]: + def config_init(cls) -> Tuple[RSConfig, APIServerConfig]: env_config = RSConfig( tokenizer_name="Qwen/Qwen2.5-7B", group_size=16, @@ -152,10 +153,11 @@ class MathEnv(BaseEnv): eval_limit_ratio=0.1, max_num_workers_per_node=24, ) - server_configs = ServerBaseline( + server_configs = APIServerConfig( model_name="Qwen/Qwen2.5-7B", num_requests_for_eval=256, # since evaling only on one... 
server_type="vllm", + base_url="", # Override via CLI: --openai.base_url ) return env_config, server_configs From b492ac4fce84e7f9eea03d0a0329daa020738404 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Mon, 16 Feb 2026 17:39:37 -0500 Subject: [PATCH 03/23] on policy changes --- atroposlib/envs/base.py | 25 +++++++++++++++++-------- environments/math_server_zero.py | 8 +++----- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index 81569f0d..75ed08e2 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -1654,20 +1654,29 @@ class BaseEnv(ABC): cli_passed_flags, openai_full_prefix ) # CLI args yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {}) - if isinstance(default_server_configs, ServerBaseline) and ( + + # Auto-convert ServerBaseline to APIServerConfig when CLI/YAML overrides are provided + # This allows any environment to use --openai.* CLI args without modifying config_init + # Use a new variable to avoid UnboundLocalError from closure scoping + effective_server_configs = default_server_configs + if isinstance(effective_server_configs, ServerBaseline) and ( oai_cli_passed_args or yaml_oai_config ): - raise ValueError( - "ServerBaseline is not compatible with OpenAI-namespaced CLI arguments. Please edit `config_init` directly or use APIServerConfig." 
# noqa: E501 + # Convert ServerBaseline to APIServerConfig, preserving common fields + baseline_dict = effective_server_configs.model_dump() + effective_server_configs = APIServerConfig(**baseline_dict) + logger.info( + "Auto-converted ServerBaseline to APIServerConfig for CLI/YAML overrides" ) + if ( - isinstance(default_server_configs, list) - and len(default_server_configs) == 1 + isinstance(effective_server_configs, list) + and len(effective_server_configs) == 1 ): # can't use the same var name because it shadows the class variable and we get an error - default_openai_config_ = default_server_configs[0] + default_openai_config_ = effective_server_configs[0] else: - default_openai_config_ = default_server_configs + default_openai_config_ = effective_server_configs if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1: yaml_oai_config = yaml_oai_config[0] if isinstance(default_openai_config_, APIServerConfig) and isinstance( @@ -1717,7 +1726,7 @@ class BaseEnv(ABC): # Determine the final server_configs, handling single, multiple servers, and overrides. 
openai_configs = resolve_openai_configs( - default_server_configs=default_server_configs, + default_server_configs=effective_server_configs, openai_config_dict=openai_config_dict, yaml_config=yaml_config, cli_passed_flags=cli_passed_flags, diff --git a/environments/math_server_zero.py b/environments/math_server_zero.py index 0d6c140c..1432ab4d 100644 --- a/environments/math_server_zero.py +++ b/environments/math_server_zero.py @@ -18,7 +18,6 @@ from pydantic import Field from tqdm.asyncio import tqdm_asyncio from atroposlib.envs.base import ( - APIServerConfig, BaseEnv, BaseEnvConfig, EvalHandlingEnum, @@ -120,7 +119,7 @@ class MathEnv(BaseEnv): def __init__( self, config: RSConfig, - server_configs: APIServerConfig | ServerBaseline, + server_configs: ServerBaseline, slurm=True, testing=False, ): @@ -138,7 +137,7 @@ class MathEnv(BaseEnv): self.iter = 0 @classmethod - def config_init(cls) -> Tuple[RSConfig, APIServerConfig]: + def config_init(cls) -> Tuple[RSConfig, ServerBaseline]: env_config = RSConfig( tokenizer_name="Qwen/Qwen2.5-7B", group_size=16, @@ -153,11 +152,10 @@ class MathEnv(BaseEnv): eval_limit_ratio=0.1, max_num_workers_per_node=24, ) - server_configs = APIServerConfig( + server_configs = ServerBaseline( model_name="Qwen/Qwen2.5-7B", num_requests_for_eval=256, # since evaling only on one... 
server_type="vllm", - base_url="", # Override via CLI: --openai.base_url ) return env_config, server_configs From e81400757525b83dd80539a795c16bdc9897aabc Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Mon, 16 Feb 2026 21:05:57 -0500 Subject: [PATCH 04/23] base env debugging --- atroposlib/envs/base.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index 75ed08e2..5f6052b5 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -361,7 +361,11 @@ class BaseEnv(ABC): Structure: [batch][position][top_k] = [token_id, logprob] Returns empty list if teacher_base_url is not configured. """ + logger.info(f"[TEACHER] get_teacher_logprobs called with {len(token_sequences)} sequences") + logger.info(f"[TEACHER] teacher_base_url={self.config.teacher_base_url}") + if not self.config.teacher_base_url: + logger.warning("[TEACHER] No teacher_base_url configured, returning empty") return [] if top_k is None: @@ -371,6 +375,8 @@ class BaseEnv(ABC): api_key = self.config.teacher_api_key or os.environ.get("TEACHER_API_KEY", "") model_name = self.config.teacher_model_name or "default" + logger.info(f"[TEACHER] Using model={model_name}, top_k={top_k}") + headers = {"Content-Type": "application/json"} if api_key: headers["Authorization"] = f"Bearer {api_key}" @@ -380,6 +386,7 @@ class BaseEnv(ABC): try: async with aiohttp.ClientSession() as session: for i, tokens in enumerate(token_sequences): + logger.info(f"[TEACHER] Processing sequence {i+1}/{len(token_sequences)}, {len(tokens)} tokens") # Decode tokens to text full_text = self.tokenizer.decode(tokens, skip_special_tokens=False) @@ -1124,8 +1131,11 @@ class BaseEnv(ABC): ] # Automatic on-policy distillation: fetch teacher logprobs if enabled + logger.info(f"[DISTILL DEBUG] distillation_enabled={self.config.distillation_enabled}, teacher_base_url={self.config.teacher_base_url}") if self.config.distillation_enabled and 
self.config.teacher_base_url: + logger.info(f"[DISTILL DEBUG] Distillation is enabled! Checking for existing logprobs...") if group.get("onpolicydistill_logprobs") is None: + logger.info(f"[DISTILL DEBUG] No existing logprobs, fetching from teacher...") try: teacher_logprobs = await self.get_teacher_logprobs( token_sequences=group["tokens"], @@ -1133,11 +1143,17 @@ class BaseEnv(ABC): ) if teacher_logprobs: group["onpolicydistill_logprobs"] = teacher_logprobs - logger.debug( - f"Added teacher logprobs for {len(teacher_logprobs)} sequences" + logger.info( + f"[DISTILL DEBUG] Added teacher logprobs for {len(teacher_logprobs)} sequences" ) + else: + logger.warning("[DISTILL DEBUG] get_teacher_logprobs returned empty!") except Exception as e: - logger.warning(f"Failed to fetch teacher logprobs: {e}") + logger.error(f"[DISTILL DEBUG] Failed to fetch teacher logprobs: {e}") + import traceback + logger.error(traceback.format_exc()) + else: + logger.debug(f"[DISTILL DEBUG] Distillation skipped - not enabled or no teacher URL") await self.add_rollouts_for_wandb(group, item) From ea2b388435fc96c3456a84064dc9feb8794605fd Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Mon, 16 Feb 2026 21:20:33 -0500 Subject: [PATCH 05/23] base env debugging --- atroposlib/envs/base.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index 5f6052b5..ebf27eeb 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -1671,6 +1671,12 @@ class BaseEnv(ABC): ) # CLI args yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {}) + # Debug logging for CLI args + print(f"[CLI DEBUG] cli_passed_flags = {cli_passed_flags}") + print(f"[CLI DEBUG] openai_full_prefix = {openai_full_prefix}") + print(f"[CLI DEBUG] oai_cli_passed_args = {oai_cli_passed_args}") + print(f"[CLI DEBUG] yaml_oai_config = {yaml_oai_config}") + # Auto-convert ServerBaseline to APIServerConfig when CLI/YAML overrides are provided # This allows any 
environment to use --openai.* CLI args without modifying config_init # Use a new variable to avoid UnboundLocalError from closure scoping @@ -1698,12 +1704,15 @@ class BaseEnv(ABC): if isinstance(default_openai_config_, APIServerConfig) and isinstance( yaml_oai_config, dict ): + print(f"[CLI DEBUG] default_openai_config_.model_dump() = {default_openai_config_.model_dump()}") openai_config_dict = merge_dicts( default_openai_config_.model_dump(), # Default APIServerConfig (or from class init) yaml_oai_config, oai_cli_passed_args, ) + print(f"[CLI DEBUG] openai_config_dict after merge = {openai_config_dict}") else: + print(f"[CLI DEBUG] Not merging: default_openai_config_ type={type(default_openai_config_)}, yaml_oai_config type={type(yaml_oai_config)}") openai_config_dict = {} # 3. Server Manager Configuration (slurm, testing - not namespaced) From fb23014dcc4e00e90fb34a7af709487b9424e560 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Mon, 16 Feb 2026 21:23:54 -0500 Subject: [PATCH 06/23] base env debugging --- atroposlib/envs/server_handling/openai_server.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/atroposlib/envs/server_handling/openai_server.py b/atroposlib/envs/server_handling/openai_server.py index fce40f80..d995807c 100644 --- a/atroposlib/envs/server_handling/openai_server.py +++ b/atroposlib/envs/server_handling/openai_server.py @@ -165,12 +165,17 @@ def resolve_openai_configs( """ from atroposlib.envs.server_handling.server_manager import ServerBaseline + print(f"[RESOLVE DEBUG] default_server_configs type = {type(default_server_configs)}") + print(f"[RESOLVE DEBUG] openai_config_dict = {openai_config_dict}") + openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}" openai_yaml_config = yaml_config.get(OPENAI_NAMESPACE, None) openai_cli_config = { k: v for k, v in cli_passed_flags.items() if k.startswith(openai_full_prefix) } + print(f"[RESOLVE DEBUG] openai_cli_config = {openai_cli_config}") + is_multi_server_yaml = ( 
isinstance(openai_yaml_config, list) and len(openai_yaml_config) >= 2 ) @@ -180,6 +185,9 @@ def resolve_openai_configs( and len(default_server_configs) >= 2 ) + print(f"[RESOLVE DEBUG] is_multi_server_yaml={is_multi_server_yaml}, is_multi_server_default={is_multi_server_default}") + print(f"[RESOLVE DEBUG] isinstance(default_server_configs, ServerBaseline) = {isinstance(default_server_configs, ServerBaseline)}") + if (is_multi_server_yaml or is_multi_server_default) and openai_cli_config: raise FailedExecutionException( message=f"CLI overrides for OpenAI settings (--{openai_full_prefix}*) are not supported " @@ -189,6 +197,7 @@ def resolve_openai_configs( ) if is_multi_server_yaml: + print("[RESOLVE DEBUG] Taking multi-server YAML path") logger.info( f"Using multi-server configuration defined in YAML under '{OPENAI_NAMESPACE}'." ) @@ -199,12 +208,15 @@ def resolve_openai_configs( f"Error parsing multi-server OpenAI configuration from YAML under '{OPENAI_NAMESPACE}': {e}" ) from e elif isinstance(default_server_configs, ServerBaseline): + print("[RESOLVE DEBUG] Taking ServerBaseline path") logger.info("Using ServerBaseline configuration.") server_configs = default_server_configs elif is_multi_server_default: + print("[RESOLVE DEBUG] Taking multi-server default path") logger.info("Using default multi-server configuration (length >= 2).") server_configs = default_server_configs else: + print("[RESOLVE DEBUG] Taking single server merged path") logger.info( "Using single OpenAI server configuration based on merged settings (default/YAML/CLI)." 
) @@ -216,9 +228,12 @@ def resolve_openai_configs( f"Merged Dict: {openai_config_dict}" ) from e + print(f"[RESOLVE DEBUG] final_openai_config = {final_openai_config}") if isinstance(default_server_configs, APIServerConfig): + print("[RESOLVE DEBUG] Returning final_openai_config directly") server_configs = final_openai_config elif isinstance(default_server_configs, list): + print("[RESOLVE DEBUG] Returning [final_openai_config]") server_configs = [final_openai_config] else: logger.warning( @@ -227,4 +242,5 @@ def resolve_openai_configs( ) server_configs = [final_openai_config] + print(f"[RESOLVE DEBUG] Returning server_configs = {server_configs}") return server_configs From 0510ca9b724b79dabe47ba9985bbda8f22b56995 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Mon, 16 Feb 2026 21:26:44 -0500 Subject: [PATCH 07/23] found bug --- atroposlib/envs/server_handling/openai_server.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/atroposlib/envs/server_handling/openai_server.py b/atroposlib/envs/server_handling/openai_server.py index d995807c..bbf0c15b 100644 --- a/atroposlib/envs/server_handling/openai_server.py +++ b/atroposlib/envs/server_handling/openai_server.py @@ -207,7 +207,20 @@ def resolve_openai_configs( raise FailedExecutionException( f"Error parsing multi-server OpenAI configuration from YAML under '{OPENAI_NAMESPACE}': {e}" ) from e + elif isinstance(default_server_configs, APIServerConfig): + # Check APIServerConfig BEFORE ServerBaseline since APIServerConfig inherits from ServerBaseline + print("[RESOLVE DEBUG] Taking APIServerConfig merged path") + logger.info("Using single OpenAI server configuration based on merged settings (default/YAML/CLI).") + try: + final_openai_config = APIServerConfig(**openai_config_dict) + except Exception as e: + raise FailedExecutionException( + f"Error creating final OpenAI configuration from merged settings: {e}\n" + f"Merged Dict: {openai_config_dict}" + ) from e + 
server_configs = final_openai_config elif isinstance(default_server_configs, ServerBaseline): + # Pure ServerBaseline (not APIServerConfig) - no CLI overrides possible print("[RESOLVE DEBUG] Taking ServerBaseline path") logger.info("Using ServerBaseline configuration.") server_configs = default_server_configs @@ -216,7 +229,7 @@ def resolve_openai_configs( logger.info("Using default multi-server configuration (length >= 2).") server_configs = default_server_configs else: - print("[RESOLVE DEBUG] Taking single server merged path") + print("[RESOLVE DEBUG] Taking single server merged path (fallback)") logger.info( "Using single OpenAI server configuration based on merged settings (default/YAML/CLI)." ) From c89854a350b276b9dd3bef9d9d996042323f6da1 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Tue, 17 Feb 2026 08:15:07 -0500 Subject: [PATCH 08/23] debug changes --- atroposlib/envs/base.py | 35 ++++++-- environments/math_server_zero.py | 138 +++++++++++++++++++++++++++++++ 2 files changed, 166 insertions(+), 7 deletions(-) diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index ebf27eeb..a3ac2117 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -209,8 +209,12 @@ class BaseEnvConfig(BaseModel): "no thinking prompt is injected. Use HERMES_REASONING_PROMPT from " "eval_helpers for the standard Hermes reasoning prompt.", ) - # On-policy distillation settings + distillation_enabled: bool = Field( + default=False, + description="Enable on-policy distillation. When True, automatically fetches teacher logprobs " + "after scoring and includes them in data sent to trainer.", + ) teacher_base_url: Optional[str] = Field( default=None, description="Base URL of teacher model for distillation. Supports any OpenAI-compatible API " @@ -226,14 +230,9 @@ class BaseEnvConfig(BaseModel): description="API key for teacher model. 
Can also be set via TEACHER_API_KEY env var.", ) teacher_top_k: int = Field( - default=10, + default=20, description="Number of top logprobs to fetch from teacher model per position.", ) - distillation_enabled: bool = Field( - default=False, - description="Enable on-policy distillation. When True, automatically fetches teacher logprobs " - "after scoring and includes them in data sent to trainer.", - ) class BaseEnv(ABC): @@ -1164,6 +1163,28 @@ class BaseEnv(ABC): valid_groups.append(group) if valid_groups and do_send_to_api: + # On-policy distillation: fetch teacher logprobs if enabled + if self.config.distillation_enabled and self.config.teacher_base_url: + logger.info(f"[DISTILL] Fetching teacher logprobs for {len(valid_groups)} groups") + for group in valid_groups: + if group.get("onpolicydistill_logprobs") is None: + try: + teacher_logprobs = await self.get_teacher_logprobs( + token_sequences=group["tokens"], + messages_list=group.get("messages"), + ) + if teacher_logprobs: + group["onpolicydistill_logprobs"] = teacher_logprobs + logger.info(f"[DISTILL] Added teacher logprobs for {len(teacher_logprobs)} sequences") + else: + logger.warning("[DISTILL] get_teacher_logprobs returned empty") + except Exception as e: + logger.error(f"[DISTILL] Failed to fetch teacher logprobs: {e}") + import traceback + logger.error(traceback.format_exc()) + else: + logger.debug(f"[DISTILL] Skipped - enabled={self.config.distillation_enabled}, url={self.config.teacher_base_url}") + data_to_send_to_api: Union[ScoredDataGroup, List[ScoredDataGroup]] # send single or list of scored data groups if not original_was_list and len(valid_groups) == 1: diff --git a/environments/math_server_zero.py b/environments/math_server_zero.py index 1432ab4d..932df9dc 100644 --- a/environments/math_server_zero.py +++ b/environments/math_server_zero.py @@ -6,11 +6,17 @@ Original Repository: https://github.com/Open-Reasoner-Zero/Open-Reasoner-Zero import asyncio import random import re +import logging 
from concurrent.futures import ProcessPoolExecutor from typing import Dict, List, Optional, Tuple +import aiohttp import wandb from datasets import load_dataset + +# Set up logging for debug +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) from latex2sympy2_extended import NormalizationConfig from math_verify import LatexExtractionConfig, parse, verify from math_verify.errors import TimeoutException @@ -135,6 +141,15 @@ class MathEnv(BaseEnv): self.normal_rollouts = list() self.pass_at_groupsize = list() self.iter = 0 + + # Debug: Print distillation config + print("=" * 60) + print("[MATH_DEBUG] DISTILLATION CONFIGURATION:") + print(f"[MATH_DEBUG] distillation_enabled = {config.distillation_enabled}") + print(f"[MATH_DEBUG] teacher_base_url = {config.teacher_base_url}") + print(f"[MATH_DEBUG] teacher_model_name = {getattr(config, 'teacher_model_name', 'N/A')}") + print(f"[MATH_DEBUG] teacher_top_logprobs = {getattr(config, 'teacher_top_logprobs', 'N/A')}") + print("=" * 60) @classmethod def config_init(cls) -> Tuple[RSConfig, ServerBaseline]: @@ -252,7 +267,85 @@ class MathEnv(BaseEnv): name, ) ) + + # Debug: Test teacher connectivity if distillation is enabled + if self.config.distillation_enabled and self.config.teacher_base_url: + await self._test_teacher_connectivity() + return + + async def _test_teacher_connectivity(self): + """Test if the teacher model API is reachable.""" + print("=" * 60) + print("[MATH_DEBUG] TESTING TEACHER CONNECTIVITY...") + print(f"[MATH_DEBUG] Teacher URL: {self.config.teacher_base_url}") + print(f"[MATH_DEBUG] Teacher Model: {getattr(self.config, 'teacher_model_name', 'default')}") + + try: + async with aiohttp.ClientSession() as session: + # Test 1: Health check + health_url = self.config.teacher_base_url.replace("/v1", "") + "/health" + print(f"[MATH_DEBUG] Testing health endpoint: {health_url}") + try: + async with session.get(health_url, timeout=aiohttp.ClientTimeout(total=10)) as resp: + 
print(f"[MATH_DEBUG] Health check status: {resp.status}") + if resp.status == 200: + print("[MATH_DEBUG] ✓ Teacher health check PASSED") + else: + print(f"[MATH_DEBUG] ✗ Teacher health check FAILED: {await resp.text()}") + except Exception as e: + print(f"[MATH_DEBUG] ✗ Teacher health check ERROR: {e}") + + # Test 2: Models endpoint + models_url = f"{self.config.teacher_base_url}/models" + print(f"[MATH_DEBUG] Testing models endpoint: {models_url}") + try: + async with session.get(models_url, timeout=aiohttp.ClientTimeout(total=10)) as resp: + print(f"[MATH_DEBUG] Models endpoint status: {resp.status}") + if resp.status == 200: + data = await resp.json() + models = [m.get("id", m) for m in data.get("data", [])] + print(f"[MATH_DEBUG] ✓ Available models: {models}") + else: + print(f"[MATH_DEBUG] ✗ Models endpoint FAILED: {await resp.text()}") + except Exception as e: + print(f"[MATH_DEBUG] ✗ Models endpoint ERROR: {e}") + + # Test 3: Simple completion test + completions_url = f"{self.config.teacher_base_url}/completions" + teacher_model = getattr(self.config, 'teacher_model_name', 'default') + test_payload = { + "model": teacher_model, + "prompt": "Hello", + "max_tokens": 5, + "logprobs": 5, + "echo": True, + } + print(f"[MATH_DEBUG] Testing completions endpoint: {completions_url}") + print(f"[MATH_DEBUG] Test payload: {test_payload}") + try: + async with session.post( + completions_url, + json=test_payload, + headers={"Content-Type": "application/json"}, + timeout=aiohttp.ClientTimeout(total=30), + ) as resp: + print(f"[MATH_DEBUG] Completions status: {resp.status}") + resp_text = await resp.text() + if resp.status == 200: + print(f"[MATH_DEBUG] ✓ Teacher completions WORKING!") + print(f"[MATH_DEBUG] Response preview: {resp_text[:500]}") + else: + print(f"[MATH_DEBUG] ✗ Teacher completions FAILED: {resp_text[:500]}") + except Exception as e: + print(f"[MATH_DEBUG] ✗ Teacher completions ERROR: {e}") + + except Exception as e: + print(f"[MATH_DEBUG] ✗ Teacher 
connectivity test FAILED: {e}") + import traceback + traceback.print_exc() + + print("=" * 60) async def rollout_and_score_eval(self, question, answer, subset): async with self.server.managed_server(tokenizer=self.tokenizer) as managed: @@ -482,7 +575,52 @@ class MathEnv(BaseEnv): and (not scores["overrides"][i].get("set_advantage_to_zero", False)) ] ) + + # Debug: Log scored group creation + print(f"[MATH_DEBUG] Created ScoredDataGroup with {len(scores['tokens'])} sequences") + print(f"[MATH_DEBUG] Scores: {scores['scores']}") + print(f"[MATH_DEBUG] Token lengths: {[len(t) for t in scores['tokens']]}") + print(f"[MATH_DEBUG] Has onpolicydistill_logprobs: {'onpolicydistill_logprobs' in scores}") + return scores + + async def handle_send_to_api( + self, + scored_data, + item=None, + do_send_to_api: bool = True, + abort_on_any_max_length_exceeded: bool = True, + ): + """Override to add debugging for distillation.""" + print(f"[MATH_DEBUG] handle_send_to_api called") + print(f"[MATH_DEBUG] distillation_enabled: {self.config.distillation_enabled}") + print(f"[MATH_DEBUG] teacher_base_url: {self.config.teacher_base_url}") + + if isinstance(scored_data, list): + for i, group in enumerate(scored_data): + if group: + has_distill = 'onpolicydistill_logprobs' in group and group.get('onpolicydistill_logprobs') is not None + print(f"[MATH_DEBUG] Group {i}: {len(group.get('tokens', []))} seqs, has_distill_logprobs={has_distill}") + elif scored_data: + has_distill = 'onpolicydistill_logprobs' in scored_data and scored_data.get('onpolicydistill_logprobs') is not None + print(f"[MATH_DEBUG] Single group: {len(scored_data.get('tokens', []))} seqs, has_distill_logprobs={has_distill}") + + # Call parent implementation which does the actual distillation fetch + result = await super().handle_send_to_api( + scored_data, item, do_send_to_api, abort_on_any_max_length_exceeded + ) + + # Debug: Check if distillation was added after parent call + if isinstance(scored_data, list): + for i, 
group in enumerate(scored_data): + if group: + has_distill = 'onpolicydistill_logprobs' in group and group.get('onpolicydistill_logprobs') is not None + print(f"[MATH_DEBUG] AFTER: Group {i} has_distill_logprobs={has_distill}") + elif scored_data: + has_distill = 'onpolicydistill_logprobs' in scored_data and scored_data.get('onpolicydistill_logprobs') is not None + print(f"[MATH_DEBUG] AFTER: Single group has_distill_logprobs={has_distill}") + + return result async def get_next_item(self): while True: From 79e392c4468c310a503f73d608afd3799617f9a0 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Fri, 20 Feb 2026 00:32:47 -0500 Subject: [PATCH 09/23] post merge changes --- atroposlib/api/server.py | 10 +- atroposlib/envs/base.py | 249 +++++++++++++++++++++---------- environments/math_server_zero.py | 29 +++- 3 files changed, 200 insertions(+), 88 deletions(-) diff --git a/atroposlib/api/server.py b/atroposlib/api/server.py index f6b1d4d1..978a6f25 100644 --- a/atroposlib/api/server.py +++ b/atroposlib/api/server.py @@ -145,9 +145,10 @@ class ScoredData(BaseModel): group_overrides: Optional[dict] = None images: Optional[Any] = None env_id: Optional[int] = None # ID of the environment that generated this data - # On-policy distillation: top-K logprobs from teacher model - # Structure: [sequence][position][top_k] = [token_id, logprob] - onpolicydistill_logprobs: Optional[List[List[List[List]]]] = None + # On-policy distillation (new format): parallel token ids + logprobs. 
+ # Shape for both: [sequence][position][top_k] + distill_token_ids: Optional[List[List[List[int]]]] = None + distill_logprobs: Optional[List[List[List[float]]]] = None @field_validator("messages", mode="before") @classmethod @@ -185,7 +186,8 @@ def _scored_data_to_dict(scored_data: ScoredData) -> Dict[str, Any]: "group_overrides": scored_data.group_overrides, "images": scored_data.images, "env_id": scored_data.env_id, - "onpolicydistill_logprobs": scored_data.onpolicydistill_logprobs, + "distill_token_ids": scored_data.distill_token_ids, + "distill_logprobs": scored_data.distill_logprobs, } diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index a3ac2117..25054e0c 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -66,9 +66,10 @@ class ScoredDataGroup(TypedDict): group_overrides: Optional[Dict] overrides: Optional[List[Dict]] images: Optional[Any] - # On-policy distillation: top-K logprobs from teacher model - # Structure: List[List[List[List[Union[int, float]]]]] = [sequence][position][top_k] = [token_id, logprob] - onpolicydistill_logprobs: Optional[List[List[List[List]]]] + # On-policy distillation (new format): parallel token ids + logprobs. + # distill_token_ids/distill_logprobs are [sequence][position][top_k] + distill_token_ids: Optional[List[List[List[int]]]] + distill_logprobs: Optional[List[List[List[float]]]] class ScoredDataItem(TypedDict): @@ -81,8 +82,9 @@ class ScoredDataItem(TypedDict): group_overrides: Optional[Dict] overrides: Optional[Dict] images: Optional[Any] - # On-policy distillation: top-K logprobs from teacher model per position - onpolicydistill_logprobs: Optional[List[List[List]]] + # On-policy distillation (new format): parallel token ids + logprobs per position. 
+ distill_token_ids: Optional[List[List[int]]] + distill_logprobs: Optional[List[List[float]]] class EvalHandlingEnum(Enum): @@ -233,6 +235,18 @@ class BaseEnvConfig(BaseModel): default=20, description="Number of top logprobs to fetch from teacher model per position.", ) + teacher_prefix_text: Optional[str] = Field( + default=None, + description="Optional text prefix prepended to teacher scoring prompt. " + "Useful for behavior steering. Prefix token positions are trimmed out " + "before sending distillation arrays to the trainer, preserving alignment.", + ) + teacher_system_prompt: Optional[str] = Field( + default=None, + description="Optional teacher system prompt. For completion-style teacher APIs, " + "this is converted to a textual prefix. For chat fallback, this is injected " + "as a leading system message.", + ) class BaseEnv(ABC): @@ -343,7 +357,7 @@ class BaseEnv(ABC): token_sequences: List[List[int]], messages_list: Optional[List[List[Dict]]] = None, top_k: Optional[int] = None, - ) -> List[List[List[List]]]: + ) -> Tuple[List[List[List[int]]], List[List[List[float]]]]: """ Fetch top-K logprobs from teacher model for given sequences. @@ -356,16 +370,16 @@ class BaseEnv(ABC): top_k: Number of top logprobs to fetch (defaults to config.teacher_top_k) Returns: - List of top-K logprobs per position per sequence - Structure: [batch][position][top_k] = [token_id, logprob] - Returns empty list if teacher_base_url is not configured. + Tuple of (distill_token_ids, distill_logprobs), both shaped as: + [batch][position][top_k]. + Returns ([], []) if teacher_base_url is not configured. 
""" logger.info(f"[TEACHER] get_teacher_logprobs called with {len(token_sequences)} sequences") logger.info(f"[TEACHER] teacher_base_url={self.config.teacher_base_url}") if not self.config.teacher_base_url: logger.warning("[TEACHER] No teacher_base_url configured, returning empty") - return [] + return [], [] if top_k is None: top_k = self.config.teacher_top_k @@ -380,14 +394,29 @@ class BaseEnv(ABC): if api_key: headers["Authorization"] = f"Bearer {api_key}" - results = [] + token_id_results: List[List[List[int]]] = [] + logprob_results: List[List[List[float]]] = [] try: async with aiohttp.ClientSession() as session: for i, tokens in enumerate(token_sequences): logger.info(f"[TEACHER] Processing sequence {i+1}/{len(token_sequences)}, {len(tokens)} tokens") - # Decode tokens to text - full_text = self.tokenizer.decode(tokens, skip_special_tokens=False) + # Decode original sequence and optionally prepend teacher steering text. + base_text = self.tokenizer.decode(tokens, skip_special_tokens=False) + steering_prefix = "" + if self.config.teacher_system_prompt: + steering_prefix += ( + "System instruction:\n" + f"{self.config.teacher_system_prompt.strip()}\n\n" + ) + if self.config.teacher_prefix_text: + steering_prefix += self.config.teacher_prefix_text + full_text = steering_prefix + base_text + prefix_token_len = ( + len(self.tokenizer.encode(steering_prefix, add_special_tokens=False)) + if steering_prefix + else 0 + ) # Try vLLM-style completions first (supports prompt_logprobs) # This is most efficient as it doesn't generate new tokens @@ -409,9 +438,18 @@ class BaseEnv(ABC): ) as response: if response.status == 200: data = await response.json() - seq_result = self._parse_completion_logprobs(data, top_k) - if seq_result: - results.append(seq_result) + seq_token_ids, seq_logprobs = self._parse_completion_logprobs( + data, top_k + ) + if seq_token_ids and seq_logprobs: + aligned_ids, aligned_lps = self._align_teacher_topk_to_tokens( + seq_token_ids, + seq_logprobs, 
+ target_token_len=len(tokens), + prefix_token_len=prefix_token_len, + ) + token_id_results.append(aligned_ids) + logprob_results.append(aligned_lps) continue except Exception: pass # Fall through to chat completions @@ -419,10 +457,25 @@ class BaseEnv(ABC): # Fallback: Use chat/completions with logprobs (OpenAI style) # This requires messages format if messages_list and i < len(messages_list): - messages = messages_list[i] + messages = list(messages_list[i]) + if self.config.teacher_system_prompt: + messages = [ + { + "role": "system", + "content": self.config.teacher_system_prompt, + } + ] + messages else: # Convert text to simple message format - messages = [{"role": "user", "content": full_text}] + messages = [] + if self.config.teacher_system_prompt: + messages.append( + { + "role": "system", + "content": self.config.teacher_system_prompt, + } + ) + messages.append({"role": "user", "content": full_text}) chat_request = { "model": model_name, @@ -442,40 +495,88 @@ class BaseEnv(ABC): ) as response: if response.status == 200: data = await response.json() - seq_result = self._parse_chat_logprobs(data, top_k) - results.append(seq_result) + seq_token_ids, seq_logprobs = self._parse_chat_logprobs( + data, top_k + ) + # Chat fallback logprobs are for generated tokens, not prompt tokens. + # To keep alignment correct for distillation, return empty per-position rows. 
+ if seq_token_ids and len(seq_token_ids) >= len(tokens): + aligned_ids, aligned_lps = self._align_teacher_topk_to_tokens( + seq_token_ids, + seq_logprobs, + target_token_len=len(tokens), + prefix_token_len=0, + ) + else: + aligned_ids = [[] for _ in range(len(tokens))] + aligned_lps = [[] for _ in range(len(tokens))] + token_id_results.append(aligned_ids) + logprob_results.append(aligned_lps) else: logger.warning(f"Teacher API returned {response.status}") - results.append([]) + token_id_results.append([[] for _ in range(len(tokens))]) + logprob_results.append([[] for _ in range(len(tokens))]) except Exception as e: logger.warning(f"Teacher chat request failed: {e}") - results.append([]) + token_id_results.append([[] for _ in range(len(tokens))]) + logprob_results.append([[] for _ in range(len(tokens))]) - return results + return token_id_results, logprob_results except Exception as e: logger.error(f"Error fetching teacher logprobs: {e}") - return [] + return [], [] + + def _align_teacher_topk_to_tokens( + self, + seq_token_ids: List[List[int]], + seq_logprobs: List[List[float]], + target_token_len: int, + prefix_token_len: int = 0, + ) -> Tuple[List[List[int]], List[List[float]]]: + """ + Trim teacher prefix positions and enforce exact length alignment with source tokens. + """ + n = min(len(seq_token_ids), len(seq_logprobs)) + aligned_ids = list(seq_token_ids[:n]) + aligned_lps = list(seq_logprobs[:n]) + + if prefix_token_len > 0: + aligned_ids = aligned_ids[prefix_token_len:] + aligned_lps = aligned_lps[prefix_token_len:] + + # Truncate any tail token (e.g., generated token when max_tokens>0 with echo=True) + aligned_ids = aligned_ids[:target_token_len] + aligned_lps = aligned_lps[:target_token_len] + + # Pad missing positions to preserve strict [seq][position][top_k] shape. 
+ if len(aligned_ids) < target_token_len: + pad_count = target_token_len - len(aligned_ids) + aligned_ids.extend([[] for _ in range(pad_count)]) + aligned_lps.extend([[] for _ in range(pad_count)]) + + return aligned_ids, aligned_lps def _parse_completion_logprobs( self, data: Dict, top_k: int - ) -> List[List[List]]: - """Parse logprobs from vLLM-style completion response.""" + ) -> Tuple[List[List[int]], List[List[float]]]: + """Parse token ids + logprobs from vLLM-style completion response.""" try: choice = data.get("choices", [{}])[0] logprobs_data = choice.get("logprobs", {}) # vLLM returns top_logprobs as list of dicts top_logprobs = logprobs_data.get("top_logprobs", []) - tokens = logprobs_data.get("tokens", []) if not top_logprobs: - return [] - - seq_result = [] + return [], [] + + seq_token_ids: List[List[int]] = [] + seq_logprobs: List[List[float]] = [] for pos_logprobs in top_logprobs: if pos_logprobs is None: - seq_result.append([]) + seq_token_ids.append([]) + seq_logprobs.append([]) elif isinstance(pos_logprobs, dict): # Format: {token_str: logprob, ...} sorted_items = sorted( @@ -483,51 +584,59 @@ class BaseEnv(ABC): key=lambda x: x[1], reverse=True )[:top_k] - pos_result = [] + pos_ids: List[int] = [] + pos_lps: List[float] = [] for token_str, logprob in sorted_items: # Convert token string to ID token_ids = self.tokenizer.encode(token_str, add_special_tokens=False) if token_ids: - pos_result.append([token_ids[0], float(logprob)]) - seq_result.append(pos_result) + pos_ids.append(int(token_ids[0])) + pos_lps.append(float(logprob)) + seq_token_ids.append(pos_ids) + seq_logprobs.append(pos_lps) else: - seq_result.append([]) - - return seq_result + seq_token_ids.append([]) + seq_logprobs.append([]) + + return seq_token_ids, seq_logprobs except Exception as e: logger.warning(f"Error parsing completion logprobs: {e}") - return [] + return [], [] def _parse_chat_logprobs( self, data: Dict, top_k: int - ) -> List[List[List]]: - """Parse logprobs from 
OpenAI-style chat completion response.""" + ) -> Tuple[List[List[int]], List[List[float]]]: + """Parse token ids + logprobs from OpenAI-style chat completion response.""" try: choice = data.get("choices", [{}])[0] logprobs_data = choice.get("logprobs", {}) if not logprobs_data: - return [] + return [], [] content = logprobs_data.get("content", []) - seq_result = [] + seq_token_ids: List[List[int]] = [] + seq_logprobs: List[List[float]] = [] for token_data in content: top_logprobs = token_data.get("top_logprobs", []) - pos_result = [] + pos_ids: List[int] = [] + pos_lps: List[float] = [] for item in top_logprobs[:top_k]: token_str = item.get("token", "") logprob = item.get("logprob", 0.0) # Convert token string to ID token_ids = self.tokenizer.encode(token_str, add_special_tokens=False) if token_ids: - pos_result.append([token_ids[0], float(logprob)]) - seq_result.append(pos_result) - - return seq_result + pos_ids.append(int(token_ids[0])) + pos_lps.append(float(logprob)) + seq_token_ids.append(pos_ids) + seq_logprobs.append(pos_lps) + + return seq_token_ids, seq_logprobs except Exception as e: logger.warning(f"Error parsing chat logprobs: {e}") - return [] + return [], [] @classmethod def config_init( @@ -1108,6 +1217,8 @@ class BaseEnv(ABC): group.setdefault("ref_logprobs", None) group.setdefault("overrides", None) group.setdefault("group_overrides", None) + group.setdefault("distill_token_ids", None) + group.setdefault("distill_logprobs", None) for mask in group["masks"]: self.completion_lengths.append(sum(m != -100 for m in mask)) @@ -1129,31 +1240,6 @@ class BaseEnv(ABC): for i in range(len(group["tokens"])) ] - # Automatic on-policy distillation: fetch teacher logprobs if enabled - logger.info(f"[DISTILL DEBUG] distillation_enabled={self.config.distillation_enabled}, teacher_base_url={self.config.teacher_base_url}") - if self.config.distillation_enabled and self.config.teacher_base_url: - logger.info(f"[DISTILL DEBUG] Distillation is enabled! 
Checking for existing logprobs...") - if group.get("onpolicydistill_logprobs") is None: - logger.info(f"[DISTILL DEBUG] No existing logprobs, fetching from teacher...") - try: - teacher_logprobs = await self.get_teacher_logprobs( - token_sequences=group["tokens"], - messages_list=group.get("messages"), - ) - if teacher_logprobs: - group["onpolicydistill_logprobs"] = teacher_logprobs - logger.info( - f"[DISTILL DEBUG] Added teacher logprobs for {len(teacher_logprobs)} sequences" - ) - else: - logger.warning("[DISTILL DEBUG] get_teacher_logprobs returned empty!") - except Exception as e: - logger.error(f"[DISTILL DEBUG] Failed to fetch teacher logprobs: {e}") - import traceback - logger.error(traceback.format_exc()) - else: - logger.debug(f"[DISTILL DEBUG] Distillation skipped - not enabled or no teacher URL") - await self.add_rollouts_for_wandb(group, item) if self.jsonl_writer is not None: @@ -1167,15 +1253,22 @@ class BaseEnv(ABC): if self.config.distillation_enabled and self.config.teacher_base_url: logger.info(f"[DISTILL] Fetching teacher logprobs for {len(valid_groups)} groups") for group in valid_groups: - if group.get("onpolicydistill_logprobs") is None: + has_new_format = ( + group.get("distill_token_ids") is not None + and group.get("distill_logprobs") is not None + ) + if not has_new_format: try: - teacher_logprobs = await self.get_teacher_logprobs( + teacher_token_ids, teacher_logprobs = await self.get_teacher_logprobs( token_sequences=group["tokens"], messages_list=group.get("messages"), ) - if teacher_logprobs: - group["onpolicydistill_logprobs"] = teacher_logprobs - logger.info(f"[DISTILL] Added teacher logprobs for {len(teacher_logprobs)} sequences") + if teacher_token_ids and teacher_logprobs: + group["distill_token_ids"] = teacher_token_ids + group["distill_logprobs"] = teacher_logprobs + logger.info( + f"[DISTILL] Added teacher distill arrays for {len(teacher_token_ids)} sequences" + ) else: logger.warning("[DISTILL] get_teacher_logprobs returned 
empty") except Exception as e: diff --git a/environments/math_server_zero.py b/environments/math_server_zero.py index 932df9dc..93856ea3 100644 --- a/environments/math_server_zero.py +++ b/environments/math_server_zero.py @@ -148,7 +148,9 @@ class MathEnv(BaseEnv): print(f"[MATH_DEBUG] distillation_enabled = {config.distillation_enabled}") print(f"[MATH_DEBUG] teacher_base_url = {config.teacher_base_url}") print(f"[MATH_DEBUG] teacher_model_name = {getattr(config, 'teacher_model_name', 'N/A')}") - print(f"[MATH_DEBUG] teacher_top_logprobs = {getattr(config, 'teacher_top_logprobs', 'N/A')}") + print(f"[MATH_DEBUG] teacher_top_k = {getattr(config, 'teacher_top_k', 'N/A')}") + print(f"[MATH_DEBUG] teacher_prefix_text set = {bool(getattr(config, 'teacher_prefix_text', None))}") + print(f"[MATH_DEBUG] teacher_system_prompt set = {bool(getattr(config, 'teacher_system_prompt', None))}") print("=" * 60) @classmethod @@ -580,7 +582,10 @@ class MathEnv(BaseEnv): print(f"[MATH_DEBUG] Created ScoredDataGroup with {len(scores['tokens'])} sequences") print(f"[MATH_DEBUG] Scores: {scores['scores']}") print(f"[MATH_DEBUG] Token lengths: {[len(t) for t in scores['tokens']]}") - print(f"[MATH_DEBUG] Has onpolicydistill_logprobs: {'onpolicydistill_logprobs' in scores}") + has_new_distill = ( + "distill_token_ids" in scores and "distill_logprobs" in scores + ) + print(f"[MATH_DEBUG] Has distill arrays: {has_new_distill}") return scores @@ -599,10 +604,16 @@ class MathEnv(BaseEnv): if isinstance(scored_data, list): for i, group in enumerate(scored_data): if group: - has_distill = 'onpolicydistill_logprobs' in group and group.get('onpolicydistill_logprobs') is not None + has_distill = ( + group.get("distill_token_ids") is not None + and group.get("distill_logprobs") is not None + ) print(f"[MATH_DEBUG] Group {i}: {len(group.get('tokens', []))} seqs, has_distill_logprobs={has_distill}") elif scored_data: - has_distill = 'onpolicydistill_logprobs' in scored_data and 
scored_data.get('onpolicydistill_logprobs') is not None + has_distill = ( + scored_data.get("distill_token_ids") is not None + and scored_data.get("distill_logprobs") is not None + ) print(f"[MATH_DEBUG] Single group: {len(scored_data.get('tokens', []))} seqs, has_distill_logprobs={has_distill}") # Call parent implementation which does the actual distillation fetch @@ -614,10 +625,16 @@ class MathEnv(BaseEnv): if isinstance(scored_data, list): for i, group in enumerate(scored_data): if group: - has_distill = 'onpolicydistill_logprobs' in group and group.get('onpolicydistill_logprobs') is not None + has_distill = ( + group.get("distill_token_ids") is not None + and group.get("distill_logprobs") is not None + ) print(f"[MATH_DEBUG] AFTER: Group {i} has_distill_logprobs={has_distill}") elif scored_data: - has_distill = 'onpolicydistill_logprobs' in scored_data and scored_data.get('onpolicydistill_logprobs') is not None + has_distill = ( + scored_data.get("distill_token_ids") is not None + and scored_data.get("distill_logprobs") is not None + ) print(f"[MATH_DEBUG] AFTER: Single group has_distill_logprobs={has_distill}") return result From 1c90fc71b0a0d79e6041ec858a3bbf6733044a31 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Fri, 20 Feb 2026 00:35:29 -0500 Subject: [PATCH 10/23] on policy clean up --- README.md | 127 ++++++++++++++ atroposlib/envs/base.py | 12 +- .../envs/server_handling/managed_server.py | 18 ++ .../envs/server_handling/openai_server.py | 10 +- .../envs/server_handling/vllm_server.py | 25 +++ environments/gsm8k_server.py | 77 ++++++++- environments/math_server_zero.py | 161 +----------------- 7 files changed, 262 insertions(+), 168 deletions(-) diff --git a/README.md b/README.md index b7ed312e..ddb2bac7 100644 --- a/README.md +++ b/README.md @@ -256,6 +256,133 @@ Atropos repo contains an example trainer that should primarily be used as a refe To use the example trainer, see this page: [training example guide](example_trainer/README.md) +## 
On-Policy Distillation (Environment + API Flow) + +Atropos supports on-policy distillation by fetching teacher top-k token distributions for the same trajectories that are used for RL, then attaching that teacher data to the batches consumed by your trainer. + +### How the flow works + +1. **Student rollouts are generated by the environment** + - Example: `environments/gsm8k_server.py` samples `group_size` completions from your student inference server. +2. **Environment scores and validates the group** + - Normal Atropos filtering still applies (group structure, optional score-equality checks, max token checks, etc.). +3. **Teacher logprobs are fetched in `BaseEnv.handle_send_to_api`** + - If `distillation_enabled=true` and `teacher_base_url` is set, Atropos calls the teacher endpoint and builds: + - `distill_token_ids` with shape `[sequence][position][top_k]` + - `distill_logprobs` with shape `[sequence][position][top_k]` +4. **Distillation arrays are attached to each scored group** + - Added as `group["distill_token_ids"]` and `group["distill_logprobs"]`. +5. **Atropos API stores and serves these fields unchanged** + - `/scored_data` and `/batch` include the distillation arrays. +6. **Trainer consumes both RL and distillation signals** + - Example trainer computes GRPO + distillation loss from the same batch. 
+ +### Configuration knobs in environments + +Distillation is configured in `BaseEnvConfig` and available via CLI under `--env.*`: + +- `--env.distillation_enabled true` +- `--env.teacher_base_url http://localhost:8003/v1` +- `--env.teacher_model_name "$TEACHER_MODEL"` +- `--env.teacher_api_key "$TEACHER_API_KEY"` (or `TEACHER_API_KEY`) +- `--env.teacher_top_k 20` +- Optional steering controls: + - `--env.teacher_prefix_text "..."` + - `--env.teacher_system_prompt "..."` + +### Self-distillation vs cross-model distillation + +Both setups are supported: + +- **Self-distillation (same model family for teacher and student)** + Point `teacher_base_url` to a server running the same model (or equivalent checkpoint family) as the student. This is the most stable setup for token-level alignment. + +- **Cross-model distillation (different teacher and student models)** + Also supported, but tokenization compatibility becomes more important. If token vocabularies/template behavior differ significantly, alignment quality may degrade. + +In practice, self-distillation is usually easiest to bring up first, then cross-model can be layered in once your pipeline is stable. + +### Tokenization and alignment details + +Atropos handles tokenization in two places: + +1. **Student rollout path (`server_type=vllm`)** + - The `/generate` request is built via the vLLM server handler and uses the server-side tokenizer configured by: + - `--openai.tokenizer_name` (or falls back to `--openai.model_name`) + - Recommendation: set `--openai.tokenizer_name` explicitly to match the student serving model. + +2. **Teacher top-k parsing path** + - Teacher responses are parsed into token ids/logprobs in `BaseEnv.get_teacher_logprobs`. + - The parser maps teacher token strings into ids using the environment tokenizer (`self.tokenizer`) and then aligns to student sequence length. + +Because distillation is token-position based, keeping tokenizer families compatible is strongly recommended, especially for cross-model distillation.
+ +### Minimal bring-up example + +Run each command in a separate terminal. + +1. **Start Atropos API** +```bash +run-api --port 8002 +``` + +2. **Start teacher server (OpenAI-compatible endpoint)** +```bash +python -m vllm.entrypoints.openai.api_server \ + --model "$TEACHER_MODEL" \ + --host 0.0.0.0 \ + --port 8003 +``` + +3. **Start student server for environments (`/generate` endpoint)** +```bash +python -m example_trainer.vllm_api_server \ + --model "$STUDENT_MODEL" \ + --port 9001 +``` + +4. **Start environment with distillation enabled** +```bash +python environments/gsm8k_server.py serve \ + --env.rollout_server_url "http://localhost:8002" \ + --env.distillation_enabled true \ + --env.teacher_base_url "http://localhost:8003/v1" \ + --env.teacher_model_name "$TEACHER_MODEL" \ + --env.teacher_top_k 20 \ + --openai.server_type vllm \ + --openai.base_url "http://localhost:9001/v1" \ + --openai.model_name "$STUDENT_MODEL" \ + --openai.tokenizer_name "$STUDENT_MODEL" +``` + +5. **Start trainer with distillation flags** +```bash +python -m example_trainer.grpo \ + --atropos-url "http://localhost:8002" \ + --distillation-enabled \ + --distillation-coef 0.1 \ + --distillation-loss-type kl \ + --distillation-temperature 1.0 +``` + +### Verification checklist + +- Environment logs show distillation fetch: + - `[DISTILL] Fetching teacher logprobs ...` + - `[DISTILL] Added teacher distill arrays ...` +- Teacher logs show completion/chat requests from the environment. +- API contains distill fields in latest example: +```bash +curl -s http://localhost:8002/latest_example | jq '{has_ids:(.distill_token_ids!=null), has_lps:(.distill_logprobs!=null)}' +``` +- Trainer logs report distillation metrics (example trainer): + - `Distill: loss=...` + +### Important notes + +- For `server_type=vllm`, the environment expects a server exposing `/generate` (the custom server in `example_trainer/vllm_api_server.py`), not only `/v1/chat/completions`. 
+- Prefer explicitly setting `--openai.tokenizer_name` to your student tokenizer to avoid prompt token-ID mismatch. + --- ## Testing and Debugging Tools diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index 25054e0c..551eb5b0 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -1276,7 +1276,11 @@ class BaseEnv(ABC): import traceback logger.error(traceback.format_exc()) else: - logger.debug(f"[DISTILL] Skipped - enabled={self.config.distillation_enabled}, url={self.config.teacher_base_url}") + logger.debug( + "[DISTILL] Skipped - enabled=%s, url=%s", + self.config.distillation_enabled, + self.config.teacher_base_url, + ) data_to_send_to_api: Union[ScoredDataGroup, List[ScoredDataGroup]] # send single or list of scored data groups @@ -1826,7 +1830,11 @@ class BaseEnv(ABC): ) print(f"[CLI DEBUG] openai_config_dict after merge = {openai_config_dict}") else: - print(f"[CLI DEBUG] Not merging: default_openai_config_ type={type(default_openai_config_)}, yaml_oai_config type={type(yaml_oai_config)}") + print( + "[CLI DEBUG] Not merging: default_openai_config_ " + f"type={type(default_openai_config_)}, " + f"yaml_oai_config type={type(yaml_oai_config)}" + ) openai_config_dict = {} # 3. 
Server Manager Configuration (slurm, testing - not namespaced) diff --git a/atroposlib/envs/server_handling/managed_server.py b/atroposlib/envs/server_handling/managed_server.py index 0918c325..cb14d210 100644 --- a/atroposlib/envs/server_handling/managed_server.py +++ b/atroposlib/envs/server_handling/managed_server.py @@ -7,6 +7,7 @@ This wrapper maintains a tree structure of sequences, where: - Branching occurs organically from different contexts and n > 1 completions """ +import os import time import uuid import warnings @@ -131,6 +132,10 @@ class ManagedServer: # Fallback for tokenizers without chat template return "\n".join([f"{m['role']}: {m['content']}" for m in messages]) + def _debug_requests_enabled(self) -> bool: + """Enable verbose request construction logs with ATROPOS_DEBUG_REQUESTS=1.""" + return os.getenv("ATROPOS_DEBUG_REQUESTS", "0") == "1" + def _find_extending_node(self, input_text: str) -> Optional[SequenceNode]: """ Find a node that this input extends (default mode). 
@@ -284,6 +289,19 @@ class ManagedServer: completion_kwargs = kwargs.copy() completion_kwargs["prompt"] = prompt completion_kwargs.pop("messages", None) + if self._debug_requests_enabled(): + msg_count = len(messages) + prompt_preview = prompt.replace("\n", "\\n")[:600] + print( + f"[ATROPOS_REQ_DEBUG] chat_completion messages={msg_count} " + f"n={completion_kwargs.get('n')} max_tokens={completion_kwargs.get('max_tokens')} " + f"temperature={completion_kwargs.get('temperature')}", + flush=True, + ) + print( + f"[ATROPOS_REQ_DEBUG] prompt_preview={prompt_preview!r}", + flush=True, + ) # Set model name if not provided if "model" not in completion_kwargs: diff --git a/atroposlib/envs/server_handling/openai_server.py b/atroposlib/envs/server_handling/openai_server.py index bbf0c15b..f84558c4 100644 --- a/atroposlib/envs/server_handling/openai_server.py +++ b/atroposlib/envs/server_handling/openai_server.py @@ -185,8 +185,14 @@ def resolve_openai_configs( and len(default_server_configs) >= 2 ) - print(f"[RESOLVE DEBUG] is_multi_server_yaml={is_multi_server_yaml}, is_multi_server_default={is_multi_server_default}") - print(f"[RESOLVE DEBUG] isinstance(default_server_configs, ServerBaseline) = {isinstance(default_server_configs, ServerBaseline)}") + print( + "[RESOLVE DEBUG] is_multi_server_yaml=" + f"{is_multi_server_yaml}, is_multi_server_default={is_multi_server_default}" + ) + print( + "[RESOLVE DEBUG] isinstance(default_server_configs, ServerBaseline) = " + f"{isinstance(default_server_configs, ServerBaseline)}" + ) if (is_multi_server_yaml or is_multi_server_default) and openai_cli_config: raise FailedExecutionException( diff --git a/atroposlib/envs/server_handling/vllm_server.py b/atroposlib/envs/server_handling/vllm_server.py index 96242754..d3e7d2ea 100644 --- a/atroposlib/envs/server_handling/vllm_server.py +++ b/atroposlib/envs/server_handling/vllm_server.py @@ -2,6 +2,7 @@ # see example_trainer/vllm_api_server.py for an example import asyncio +import os import 
warnings import aiohttp @@ -189,6 +190,30 @@ class VLLMServer(APIServer): # Prepare request for VLLM native API request_data = {"prompt": {"prompt_token_ids": prompt_tokens}, "logprobs": 0} request_data.update(kwargs) + debug_requests = os.getenv("ATROPOS_DEBUG_REQUESTS", "0") == "1" + if debug_requests: + base = self.config.base_url.replace("/v1", "") + prompt_preview = self.tokenizer.decode(prompt_tokens[:256]).replace("\n", "\\n") + print( + f"[ATROPOS_REQ_DEBUG] vllm_generate_url={base}/generate " + f"prompt_token_len={len(prompt_tokens)}", + flush=True, + ) + print( + f"[ATROPOS_REQ_DEBUG] request_meta=" + f"{{'n': {request_data.get('n')}, 'max_tokens': {request_data.get('max_tokens')}, " + f"'temperature': {request_data.get('temperature')}, 'top_p': {request_data.get('top_p')}}}", + flush=True, + ) + print( + f"[ATROPOS_REQ_DEBUG] prompt_preview={prompt_preview!r}", + flush=True, + ) + print( + f"[ATROPOS_REQ_DEBUG] curl_base=curl -s -X POST {base}/generate " + '-H "Content-Type: application/json" -d \'\'', + flush=True, + ) # Make async request to VLLM /generate endpoint async with aiohttp.ClientSession() as session: diff --git a/environments/gsm8k_server.py b/environments/gsm8k_server.py index 6ae5285b..0fa53431 100644 --- a/environments/gsm8k_server.py +++ b/environments/gsm8k_server.py @@ -123,7 +123,11 @@ class GSM8kEnv(BaseEnv): async def rollout_and_score_eval(self, question: str, answer: str) -> dict: """Rollout and score evaluation with detailed sample data collection.""" - async with self.server.managed_server(tokenizer=self.tokenizer) as managed: + # Important: use ManagedServer's default tokenizer resolution so it uses + # the underlying inference server tokenizer (e.g., Qwen) instead of the + # environment tokenizer. Passing self.tokenizer here can cause token-ID + # mismatch and gibberish generations when model/tokenizer families differ. 
+ async with self.server.managed_server() as managed: completion = await managed.chat_completion( messages=[ {"role": "system", "content": system_prompt}, @@ -231,8 +235,16 @@ class GSM8kEnv(BaseEnv): gold_answer = ( "\\boxed{" + item["answer"].split("#")[-1].strip().replace(",", "") + "}" ) + question_preview = item["question"].replace("\n", " ")[:120] + print( + f"[GSM8K_DEBUG] collect_start group_size={self.config.group_size} " + f"q={question_preview!r}", + flush=True, + ) - async with self.server.managed_server(tokenizer=self.tokenizer) as managed: + # Important: do not force env tokenizer into ManagedServer for rollout. + # Let ManagedServer use the server's tokenizer to keep prompt token IDs aligned. + async with self.server.managed_server() as managed: chat_completions = await managed.chat_completion( messages=[{"role": "system", "content": system_prompt}, user_message], @@ -243,10 +255,24 @@ class GSM8kEnv(BaseEnv): state = managed.get_state() nodes = state["nodes"] + print( + f"[GSM8K_DEBUG] completion_batch_received choices={len(chat_completions.choices)} " + f"nodes={len(nodes)}", + flush=True, + ) to_score = list() to_backlog = list() for i, chat_completion in enumerate(chat_completions.choices): + response_text = chat_completion.message.content or "" + response_preview = response_text.replace("\n", " ")[:220] + valid_mask_count = sum(1 for m in nodes[i].masked_tokens if m != -100) + print( + f"[GSM8K_DEBUG] response_received idx={i} finish={chat_completion.finish_reason} " + f"tokens={len(nodes[i].tokens)} valid_masked={valid_mask_count} " + f"text={response_preview!r}", + flush=True, + ) messages = ( {"role": "system", "content": system_prompt}, user_message, @@ -263,6 +289,11 @@ class GSM8kEnv(BaseEnv): } ) to_postprocess = await self.score(to_score) + accepted = 0 if to_postprocess is None else len(to_postprocess.get("tokens", [])) + print( + f"[GSM8K_DEBUG] collect_done accepted={accepted} submitted={len(to_score)}", + flush=True, + ) return 
to_postprocess, to_backlog async def score( @@ -278,10 +309,15 @@ class GSM8kEnv(BaseEnv): extraction_mode="first_match", extraction_config=[LatexExtractionConfig()], ) + print( + f"[GSM8K_DEBUG] score_start candidates={len(rollout_group_data)} " + f"gold_parsed_len={len(gold_parsed)}", + flush=True, + ) if len(gold_parsed) != 0: # We require the answer to be provided in correct latex (no malformed operators) random.shuffle(rollout_group_data) - for item in rollout_group_data: + for idx, item in enumerate(rollout_group_data): # print(item[0][-1]["content"]) answer_parsed = parse( item["messages"][-1]["content"].split("")[-1], @@ -310,7 +346,19 @@ class GSM8kEnv(BaseEnv): logprobs = item["logprobs"] # remove obviously bad examples - if len([1 for i in masks if i != -100]) < 10: + valid_mask_count = len([1 for i in masks if i != -100]) + print( + f"[GSM8K_DEBUG] score_candidate idx={idx} parsed_len={len(answer_parsed)} " + f"reward={bool(reward)} valid_masked={valid_mask_count} " + f"tokens={len(tokens)}", + flush=True, + ) + if valid_mask_count < 10: + print( + f"[GSM8K_DEBUG] drop_candidate idx={idx} reason=valid_masked_lt_10 " + f"value={valid_mask_count}", + flush=True, + ) continue scores["tokens"].append(tokens) scores["masks"].append(masks) @@ -323,6 +371,13 @@ class GSM8kEnv(BaseEnv): for score in scores["scores"]: self.percent_correct_buffer.append(max(score, 0)) + if len(scores["scores"]) == 0: + print( + "[GSM8K_DEBUG] drop_group reason=no_valid_candidates_after_filtering", + flush=True, + ) + return None + # check if all the same # print(scores['scores']) if all([score == 1 for score in scores["scores"]]): @@ -330,6 +385,10 @@ class GSM8kEnv(BaseEnv): token_lengths = [len(token) for token in scores["tokens"]] if max(token_lengths) == 0: # What? But don't want to crash a run so just in case... 
+ print( + "[GSM8K_DEBUG] drop_group reason=zero_token_length_after_penalty_branch", + flush=True, + ) return None # Get max allowed token length from config @@ -353,10 +412,20 @@ class GSM8kEnv(BaseEnv): # Apply linear penalty scaling from 1.0 down to 0.0 scores["scores"].append(1.0 - percentage_of_range) if all([scores["scores"][0] == score for score in scores["scores"]]): + print( + f"[GSM8K_DEBUG] drop_group reason=all_scores_identical scores={scores['scores']}", + flush=True, + ) return None # If all the same, we return None + print( + f"[GSM8K_DEBUG] score_done accepted={len(scores['scores'])} " + f"scores={scores['scores']}", + flush=True, + ) return scores else: # If the gold solution is not parseable, we return None + print("[GSM8K_DEBUG] drop_group reason=gold_unparseable", flush=True) return None async def get_next_item(self) -> GSM8kRow: diff --git a/environments/math_server_zero.py b/environments/math_server_zero.py index 93856ea3..12076dab 100644 --- a/environments/math_server_zero.py +++ b/environments/math_server_zero.py @@ -6,17 +6,12 @@ Original Repository: https://github.com/Open-Reasoner-Zero/Open-Reasoner-Zero import asyncio import random import re -import logging from concurrent.futures import ProcessPoolExecutor from typing import Dict, List, Optional, Tuple -import aiohttp import wandb from datasets import load_dataset -# Set up logging for debug -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) from latex2sympy2_extended import NormalizationConfig from math_verify import LatexExtractionConfig, parse, verify from math_verify.errors import TimeoutException @@ -129,8 +124,6 @@ class MathEnv(BaseEnv): slurm=True, testing=False, ): - print("Initializing MathEnv") - print(f"Slurm: {slurm}, Testing: {testing}") super().__init__(config, server_configs, slurm, testing) self.percent_correct_buffer = list() self.eval_metrics = list() @@ -141,17 +134,6 @@ class MathEnv(BaseEnv): self.normal_rollouts = list() 
self.pass_at_groupsize = list() self.iter = 0 - - # Debug: Print distillation config - print("=" * 60) - print("[MATH_DEBUG] DISTILLATION CONFIGURATION:") - print(f"[MATH_DEBUG] distillation_enabled = {config.distillation_enabled}") - print(f"[MATH_DEBUG] teacher_base_url = {config.teacher_base_url}") - print(f"[MATH_DEBUG] teacher_model_name = {getattr(config, 'teacher_model_name', 'N/A')}") - print(f"[MATH_DEBUG] teacher_top_k = {getattr(config, 'teacher_top_k', 'N/A')}") - print(f"[MATH_DEBUG] teacher_prefix_text set = {bool(getattr(config, 'teacher_prefix_text', None))}") - print(f"[MATH_DEBUG] teacher_system_prompt set = {bool(getattr(config, 'teacher_system_prompt', None))}") - print("=" * 60) @classmethod def config_init(cls) -> Tuple[RSConfig, ServerBaseline]: @@ -269,85 +251,7 @@ class MathEnv(BaseEnv): name, ) ) - - # Debug: Test teacher connectivity if distillation is enabled - if self.config.distillation_enabled and self.config.teacher_base_url: - await self._test_teacher_connectivity() - return - - async def _test_teacher_connectivity(self): - """Test if the teacher model API is reachable.""" - print("=" * 60) - print("[MATH_DEBUG] TESTING TEACHER CONNECTIVITY...") - print(f"[MATH_DEBUG] Teacher URL: {self.config.teacher_base_url}") - print(f"[MATH_DEBUG] Teacher Model: {getattr(self.config, 'teacher_model_name', 'default')}") - - try: - async with aiohttp.ClientSession() as session: - # Test 1: Health check - health_url = self.config.teacher_base_url.replace("/v1", "") + "/health" - print(f"[MATH_DEBUG] Testing health endpoint: {health_url}") - try: - async with session.get(health_url, timeout=aiohttp.ClientTimeout(total=10)) as resp: - print(f"[MATH_DEBUG] Health check status: {resp.status}") - if resp.status == 200: - print("[MATH_DEBUG] ✓ Teacher health check PASSED") - else: - print(f"[MATH_DEBUG] ✗ Teacher health check FAILED: {await resp.text()}") - except Exception as e: - print(f"[MATH_DEBUG] ✗ Teacher health check ERROR: {e}") - - # Test 2: 
Models endpoint - models_url = f"{self.config.teacher_base_url}/models" - print(f"[MATH_DEBUG] Testing models endpoint: {models_url}") - try: - async with session.get(models_url, timeout=aiohttp.ClientTimeout(total=10)) as resp: - print(f"[MATH_DEBUG] Models endpoint status: {resp.status}") - if resp.status == 200: - data = await resp.json() - models = [m.get("id", m) for m in data.get("data", [])] - print(f"[MATH_DEBUG] ✓ Available models: {models}") - else: - print(f"[MATH_DEBUG] ✗ Models endpoint FAILED: {await resp.text()}") - except Exception as e: - print(f"[MATH_DEBUG] ✗ Models endpoint ERROR: {e}") - - # Test 3: Simple completion test - completions_url = f"{self.config.teacher_base_url}/completions" - teacher_model = getattr(self.config, 'teacher_model_name', 'default') - test_payload = { - "model": teacher_model, - "prompt": "Hello", - "max_tokens": 5, - "logprobs": 5, - "echo": True, - } - print(f"[MATH_DEBUG] Testing completions endpoint: {completions_url}") - print(f"[MATH_DEBUG] Test payload: {test_payload}") - try: - async with session.post( - completions_url, - json=test_payload, - headers={"Content-Type": "application/json"}, - timeout=aiohttp.ClientTimeout(total=30), - ) as resp: - print(f"[MATH_DEBUG] Completions status: {resp.status}") - resp_text = await resp.text() - if resp.status == 200: - print(f"[MATH_DEBUG] ✓ Teacher completions WORKING!") - print(f"[MATH_DEBUG] Response preview: {resp_text[:500]}") - else: - print(f"[MATH_DEBUG] ✗ Teacher completions FAILED: {resp_text[:500]}") - except Exception as e: - print(f"[MATH_DEBUG] ✗ Teacher completions ERROR: {e}") - - except Exception as e: - print(f"[MATH_DEBUG] ✗ Teacher connectivity test FAILED: {e}") - import traceback - traceback.print_exc() - - print("=" * 60) async def rollout_and_score_eval(self, question, answer, subset): async with self.server.managed_server(tokenizer=self.tokenizer) as managed: @@ -492,7 +396,6 @@ class MathEnv(BaseEnv): ) if len(self.normal_rollouts) > 
self.config.num_rollouts_to_keep: self.normal_rollouts.pop(0) - print(f"Collected {len(to_postprocess['scores'])} trajectories") return to_postprocess, to_backlog async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]: @@ -578,66 +481,7 @@ class MathEnv(BaseEnv): ] ) - # Debug: Log scored group creation - print(f"[MATH_DEBUG] Created ScoredDataGroup with {len(scores['tokens'])} sequences") - print(f"[MATH_DEBUG] Scores: {scores['scores']}") - print(f"[MATH_DEBUG] Token lengths: {[len(t) for t in scores['tokens']]}") - has_new_distill = ( - "distill_token_ids" in scores and "distill_logprobs" in scores - ) - print(f"[MATH_DEBUG] Has distill arrays: {has_new_distill}") - return scores - - async def handle_send_to_api( - self, - scored_data, - item=None, - do_send_to_api: bool = True, - abort_on_any_max_length_exceeded: bool = True, - ): - """Override to add debugging for distillation.""" - print(f"[MATH_DEBUG] handle_send_to_api called") - print(f"[MATH_DEBUG] distillation_enabled: {self.config.distillation_enabled}") - print(f"[MATH_DEBUG] teacher_base_url: {self.config.teacher_base_url}") - - if isinstance(scored_data, list): - for i, group in enumerate(scored_data): - if group: - has_distill = ( - group.get("distill_token_ids") is not None - and group.get("distill_logprobs") is not None - ) - print(f"[MATH_DEBUG] Group {i}: {len(group.get('tokens', []))} seqs, has_distill_logprobs={has_distill}") - elif scored_data: - has_distill = ( - scored_data.get("distill_token_ids") is not None - and scored_data.get("distill_logprobs") is not None - ) - print(f"[MATH_DEBUG] Single group: {len(scored_data.get('tokens', []))} seqs, has_distill_logprobs={has_distill}") - - # Call parent implementation which does the actual distillation fetch - result = await super().handle_send_to_api( - scored_data, item, do_send_to_api, abort_on_any_max_length_exceeded - ) - - # Debug: Check if distillation was added after parent call - if isinstance(scored_data, list): 
- for i, group in enumerate(scored_data): - if group: - has_distill = ( - group.get("distill_token_ids") is not None - and group.get("distill_logprobs") is not None - ) - print(f"[MATH_DEBUG] AFTER: Group {i} has_distill_logprobs={has_distill}") - elif scored_data: - has_distill = ( - scored_data.get("distill_token_ids") is not None - and scored_data.get("distill_logprobs") is not None - ) - print(f"[MATH_DEBUG] AFTER: Single group has_distill_logprobs={has_distill}") - - return result async def get_next_item(self): while True: @@ -652,10 +496,7 @@ class MathEnv(BaseEnv): ) break except TypeError: - print( - f"Error in getting next item, trying again, " - f"data: {next_item['question']} -> {next_item['final_answer']}" - ) + continue return (prompt, answer, "normal") From 3910a58f9baad93a9dadd1dc2ebd80433d138a51 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Fri, 20 Feb 2026 01:07:47 -0500 Subject: [PATCH 11/23] refactor base --- atroposlib/envs/base.py | 280 ++---------------- .../envs/server_handling/teacher_client.py | 273 +++++++++++++++++ 2 files changed, 290 insertions(+), 263 deletions(-) create mode 100644 atroposlib/envs/server_handling/teacher_client.py diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index 551eb5b0..74383075 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -49,6 +49,7 @@ from .server_handling.server_manager import ( ServerManager, ServerManagerConfig, ) +from .server_handling.teacher_client import TeacherClient logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -297,6 +298,9 @@ class BaseEnv(ABC): self.curr_step = 0 self.max_token_len = -1 self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name) + self.teacher_client = TeacherClient( + config=self.config, tokenizer=self.tokenizer, logger=logger + ) self.completion_lengths = [] self.max_num_workers = config.max_num_workers if self.max_num_workers == -1: @@ -358,174 +362,11 @@ class BaseEnv(ABC): messages_list: 
Optional[List[List[Dict]]] = None, top_k: Optional[int] = None, ) -> Tuple[List[List[List[int]]], List[List[List[float]]]]: - """ - Fetch top-K logprobs from teacher model for given sequences. - - Supports any OpenAI-compatible API (vLLM, OpenAI, Together, etc.). - - Args: - token_sequences: List of token ID sequences to get logprobs for - messages_list: Optional list of message histories (for chat APIs). - If provided, uses chat/completions with logprobs. - top_k: Number of top logprobs to fetch (defaults to config.teacher_top_k) - - Returns: - Tuple of (distill_token_ids, distill_logprobs), both shaped as: - [batch][position][top_k]. - Returns ([], []) if teacher_base_url is not configured. - """ - logger.info(f"[TEACHER] get_teacher_logprobs called with {len(token_sequences)} sequences") - logger.info(f"[TEACHER] teacher_base_url={self.config.teacher_base_url}") - - if not self.config.teacher_base_url: - logger.warning("[TEACHER] No teacher_base_url configured, returning empty") - return [], [] - - if top_k is None: - top_k = self.config.teacher_top_k - - # Get API key from config or environment - api_key = self.config.teacher_api_key or os.environ.get("TEACHER_API_KEY", "") - model_name = self.config.teacher_model_name or "default" - - logger.info(f"[TEACHER] Using model={model_name}, top_k={top_k}") - - headers = {"Content-Type": "application/json"} - if api_key: - headers["Authorization"] = f"Bearer {api_key}" - - token_id_results: List[List[List[int]]] = [] - logprob_results: List[List[List[float]]] = [] - - try: - async with aiohttp.ClientSession() as session: - for i, tokens in enumerate(token_sequences): - logger.info(f"[TEACHER] Processing sequence {i+1}/{len(token_sequences)}, {len(tokens)} tokens") - # Decode original sequence and optionally prepend teacher steering text. 
- base_text = self.tokenizer.decode(tokens, skip_special_tokens=False) - steering_prefix = "" - if self.config.teacher_system_prompt: - steering_prefix += ( - "System instruction:\n" - f"{self.config.teacher_system_prompt.strip()}\n\n" - ) - if self.config.teacher_prefix_text: - steering_prefix += self.config.teacher_prefix_text - full_text = steering_prefix + base_text - prefix_token_len = ( - len(self.tokenizer.encode(steering_prefix, add_special_tokens=False)) - if steering_prefix - else 0 - ) - - # Try vLLM-style completions first (supports prompt_logprobs) - # This is most efficient as it doesn't generate new tokens - request_data = { - "model": model_name, - "prompt": full_text, - "max_tokens": 1, - "temperature": 1.0, - "logprobs": top_k, - "echo": True, # Include prompt in response with logprobs - } - - try: - async with session.post( - f"{self.config.teacher_base_url}/completions", - json=request_data, - headers=headers, - timeout=aiohttp.ClientTimeout(total=120), - ) as response: - if response.status == 200: - data = await response.json() - seq_token_ids, seq_logprobs = self._parse_completion_logprobs( - data, top_k - ) - if seq_token_ids and seq_logprobs: - aligned_ids, aligned_lps = self._align_teacher_topk_to_tokens( - seq_token_ids, - seq_logprobs, - target_token_len=len(tokens), - prefix_token_len=prefix_token_len, - ) - token_id_results.append(aligned_ids) - logprob_results.append(aligned_lps) - continue - except Exception: - pass # Fall through to chat completions - - # Fallback: Use chat/completions with logprobs (OpenAI style) - # This requires messages format - if messages_list and i < len(messages_list): - messages = list(messages_list[i]) - if self.config.teacher_system_prompt: - messages = [ - { - "role": "system", - "content": self.config.teacher_system_prompt, - } - ] + messages - else: - # Convert text to simple message format - messages = [] - if self.config.teacher_system_prompt: - messages.append( - { - "role": "system", - "content": 
self.config.teacher_system_prompt, - } - ) - messages.append({"role": "user", "content": full_text}) - - chat_request = { - "model": model_name, - "messages": messages, - "max_tokens": 1, - "temperature": 1.0, - "logprobs": True, - "top_logprobs": top_k, - } - - try: - async with session.post( - f"{self.config.teacher_base_url}/chat/completions", - json=chat_request, - headers=headers, - timeout=aiohttp.ClientTimeout(total=120), - ) as response: - if response.status == 200: - data = await response.json() - seq_token_ids, seq_logprobs = self._parse_chat_logprobs( - data, top_k - ) - # Chat fallback logprobs are for generated tokens, not prompt tokens. - # To keep alignment correct for distillation, return empty per-position rows. - if seq_token_ids and len(seq_token_ids) >= len(tokens): - aligned_ids, aligned_lps = self._align_teacher_topk_to_tokens( - seq_token_ids, - seq_logprobs, - target_token_len=len(tokens), - prefix_token_len=0, - ) - else: - aligned_ids = [[] for _ in range(len(tokens))] - aligned_lps = [[] for _ in range(len(tokens))] - token_id_results.append(aligned_ids) - logprob_results.append(aligned_lps) - else: - logger.warning(f"Teacher API returned {response.status}") - token_id_results.append([[] for _ in range(len(tokens))]) - logprob_results.append([[] for _ in range(len(tokens))]) - except Exception as e: - logger.warning(f"Teacher chat request failed: {e}") - token_id_results.append([[] for _ in range(len(tokens))]) - logprob_results.append([[] for _ in range(len(tokens))]) - - return token_id_results, logprob_results - - except Exception as e: - logger.error(f"Error fetching teacher logprobs: {e}") - return [], [] + return await self.teacher_client.get_teacher_logprobs( + token_sequences=token_sequences, + messages_list=messages_list, + top_k=top_k, + ) def _align_teacher_topk_to_tokens( self, @@ -534,109 +375,22 @@ class BaseEnv(ABC): target_token_len: int, prefix_token_len: int = 0, ) -> Tuple[List[List[int]], List[List[float]]]: - """ - 
Trim teacher prefix positions and enforce exact length alignment with source tokens. - """ - n = min(len(seq_token_ids), len(seq_logprobs)) - aligned_ids = list(seq_token_ids[:n]) - aligned_lps = list(seq_logprobs[:n]) - - if prefix_token_len > 0: - aligned_ids = aligned_ids[prefix_token_len:] - aligned_lps = aligned_lps[prefix_token_len:] - - # Truncate any tail token (e.g., generated token when max_tokens>0 with echo=True) - aligned_ids = aligned_ids[:target_token_len] - aligned_lps = aligned_lps[:target_token_len] - - # Pad missing positions to preserve strict [seq][position][top_k] shape. - if len(aligned_ids) < target_token_len: - pad_count = target_token_len - len(aligned_ids) - aligned_ids.extend([[] for _ in range(pad_count)]) - aligned_lps.extend([[] for _ in range(pad_count)]) - - return aligned_ids, aligned_lps + return self.teacher_client._align_teacher_topk_to_tokens( + seq_token_ids=seq_token_ids, + seq_logprobs=seq_logprobs, + target_token_len=target_token_len, + prefix_token_len=prefix_token_len, + ) def _parse_completion_logprobs( self, data: Dict, top_k: int ) -> Tuple[List[List[int]], List[List[float]]]: - """Parse token ids + logprobs from vLLM-style completion response.""" - try: - choice = data.get("choices", [{}])[0] - logprobs_data = choice.get("logprobs", {}) - - # vLLM returns top_logprobs as list of dicts - top_logprobs = logprobs_data.get("top_logprobs", []) - - if not top_logprobs: - return [], [] - - seq_token_ids: List[List[int]] = [] - seq_logprobs: List[List[float]] = [] - for pos_logprobs in top_logprobs: - if pos_logprobs is None: - seq_token_ids.append([]) - seq_logprobs.append([]) - elif isinstance(pos_logprobs, dict): - # Format: {token_str: logprob, ...} - sorted_items = sorted( - pos_logprobs.items(), - key=lambda x: x[1], - reverse=True - )[:top_k] - pos_ids: List[int] = [] - pos_lps: List[float] = [] - for token_str, logprob in sorted_items: - # Convert token string to ID - token_ids = self.tokenizer.encode(token_str, 
add_special_tokens=False) - if token_ids: - pos_ids.append(int(token_ids[0])) - pos_lps.append(float(logprob)) - seq_token_ids.append(pos_ids) - seq_logprobs.append(pos_lps) - else: - seq_token_ids.append([]) - seq_logprobs.append([]) - - return seq_token_ids, seq_logprobs - except Exception as e: - logger.warning(f"Error parsing completion logprobs: {e}") - return [], [] + return self.teacher_client._parse_completion_logprobs(data=data, top_k=top_k) def _parse_chat_logprobs( self, data: Dict, top_k: int ) -> Tuple[List[List[int]], List[List[float]]]: - """Parse token ids + logprobs from OpenAI-style chat completion response.""" - try: - choice = data.get("choices", [{}])[0] - logprobs_data = choice.get("logprobs", {}) - - if not logprobs_data: - return [], [] - - content = logprobs_data.get("content", []) - seq_token_ids: List[List[int]] = [] - seq_logprobs: List[List[float]] = [] - - for token_data in content: - top_logprobs = token_data.get("top_logprobs", []) - pos_ids: List[int] = [] - pos_lps: List[float] = [] - for item in top_logprobs[:top_k]: - token_str = item.get("token", "") - logprob = item.get("logprob", 0.0) - # Convert token string to ID - token_ids = self.tokenizer.encode(token_str, add_special_tokens=False) - if token_ids: - pos_ids.append(int(token_ids[0])) - pos_lps.append(float(logprob)) - seq_token_ids.append(pos_ids) - seq_logprobs.append(pos_lps) - - return seq_token_ids, seq_logprobs - except Exception as e: - logger.warning(f"Error parsing chat logprobs: {e}") - return [], [] + return self.teacher_client._parse_chat_logprobs(data=data, top_k=top_k) @classmethod def config_init( diff --git a/atroposlib/envs/server_handling/teacher_client.py b/atroposlib/envs/server_handling/teacher_client.py new file mode 100644 index 00000000..a929d7dc --- /dev/null +++ b/atroposlib/envs/server_handling/teacher_client.py @@ -0,0 +1,273 @@ +import os +from typing import Dict, List, Optional, Tuple + +import aiohttp + + +class TeacherClient: + """ + 
Transport/parsing client for teacher top-k logprobs. + + This keeps distillation HTTP and parsing logic out of BaseEnv. + """ + + def __init__(self, config, tokenizer, logger): + self.config = config + self.tokenizer = tokenizer + self.logger = logger + + async def get_teacher_logprobs( + self, + token_sequences: List[List[int]], + messages_list: Optional[List[List[Dict]]] = None, + top_k: Optional[int] = None, + ) -> Tuple[List[List[List[int]]], List[List[List[float]]]]: + self.logger.info( + "[TEACHER] get_teacher_logprobs called with %s sequences", + len(token_sequences), + ) + self.logger.info("[TEACHER] teacher_base_url=%s", self.config.teacher_base_url) + + if not self.config.teacher_base_url: + self.logger.warning("[TEACHER] No teacher_base_url configured, returning empty") + return [], [] + + if top_k is None: + top_k = self.config.teacher_top_k + + api_key = self.config.teacher_api_key or os.environ.get("TEACHER_API_KEY", "") + model_name = self.config.teacher_model_name or "default" + self.logger.info("[TEACHER] Using model=%s, top_k=%s", model_name, top_k) + + headers = {"Content-Type": "application/json"} + if api_key: + headers["Authorization"] = f"Bearer {api_key}" + + token_id_results: List[List[List[int]]] = [] + logprob_results: List[List[List[float]]] = [] + + try: + async with aiohttp.ClientSession() as session: + for i, tokens in enumerate(token_sequences): + self.logger.info( + "[TEACHER] Processing sequence %s/%s, %s tokens", + i + 1, + len(token_sequences), + len(tokens), + ) + base_text = self.tokenizer.decode(tokens, skip_special_tokens=False) + steering_prefix = "" + if self.config.teacher_system_prompt: + steering_prefix += ( + "System instruction:\n" + f"{self.config.teacher_system_prompt.strip()}\n\n" + ) + if self.config.teacher_prefix_text: + steering_prefix += self.config.teacher_prefix_text + full_text = steering_prefix + base_text + prefix_token_len = ( + len( + self.tokenizer.encode( + steering_prefix, add_special_tokens=False + ) 
+ ) + if steering_prefix + else 0 + ) + + request_data = { + "model": model_name, + "prompt": full_text, + "max_tokens": 1, + "temperature": 1.0, + "logprobs": top_k, + "echo": True, + } + try: + async with session.post( + f"{self.config.teacher_base_url}/completions", + json=request_data, + headers=headers, + timeout=aiohttp.ClientTimeout(total=120), + ) as response: + if response.status == 200: + data = await response.json() + seq_token_ids, seq_logprobs = self._parse_completion_logprobs( + data, top_k + ) + if seq_token_ids and seq_logprobs: + aligned_ids, aligned_lps = self._align_teacher_topk_to_tokens( + seq_token_ids, + seq_logprobs, + target_token_len=len(tokens), + prefix_token_len=prefix_token_len, + ) + token_id_results.append(aligned_ids) + logprob_results.append(aligned_lps) + continue + except Exception: + pass + + if messages_list and i < len(messages_list): + messages = list(messages_list[i]) + if self.config.teacher_system_prompt: + messages = [ + { + "role": "system", + "content": self.config.teacher_system_prompt, + } + ] + messages + else: + messages = [] + if self.config.teacher_system_prompt: + messages.append( + { + "role": "system", + "content": self.config.teacher_system_prompt, + } + ) + messages.append({"role": "user", "content": full_text}) + + chat_request = { + "model": model_name, + "messages": messages, + "max_tokens": 1, + "temperature": 1.0, + "logprobs": True, + "top_logprobs": top_k, + } + try: + async with session.post( + f"{self.config.teacher_base_url}/chat/completions", + json=chat_request, + headers=headers, + timeout=aiohttp.ClientTimeout(total=120), + ) as response: + if response.status == 200: + data = await response.json() + seq_token_ids, seq_logprobs = self._parse_chat_logprobs( + data, top_k + ) + if seq_token_ids and len(seq_token_ids) >= len(tokens): + aligned_ids, aligned_lps = self._align_teacher_topk_to_tokens( + seq_token_ids, + seq_logprobs, + target_token_len=len(tokens), + prefix_token_len=0, + ) + else: + 
aligned_ids = [[] for _ in range(len(tokens))] + aligned_lps = [[] for _ in range(len(tokens))] + token_id_results.append(aligned_ids) + logprob_results.append(aligned_lps) + else: + self.logger.warning( + "Teacher API returned %s", response.status + ) + token_id_results.append([[] for _ in range(len(tokens))]) + logprob_results.append([[] for _ in range(len(tokens))]) + except Exception as e: + self.logger.warning("Teacher chat request failed: %s", e) + token_id_results.append([[] for _ in range(len(tokens))]) + logprob_results.append([[] for _ in range(len(tokens))]) + + return token_id_results, logprob_results + except Exception as e: + self.logger.error("Error fetching teacher logprobs: %s", e) + return [], [] + + def _align_teacher_topk_to_tokens( + self, + seq_token_ids: List[List[int]], + seq_logprobs: List[List[float]], + target_token_len: int, + prefix_token_len: int = 0, + ) -> Tuple[List[List[int]], List[List[float]]]: + n = min(len(seq_token_ids), len(seq_logprobs)) + aligned_ids = list(seq_token_ids[:n]) + aligned_lps = list(seq_logprobs[:n]) + + if prefix_token_len > 0: + aligned_ids = aligned_ids[prefix_token_len:] + aligned_lps = aligned_lps[prefix_token_len:] + + aligned_ids = aligned_ids[:target_token_len] + aligned_lps = aligned_lps[:target_token_len] + + if len(aligned_ids) < target_token_len: + pad_count = target_token_len - len(aligned_ids) + aligned_ids.extend([[] for _ in range(pad_count)]) + aligned_lps.extend([[] for _ in range(pad_count)]) + + return aligned_ids, aligned_lps + + def _parse_completion_logprobs( + self, data: Dict, top_k: int + ) -> Tuple[List[List[int]], List[List[float]]]: + try: + choice = data.get("choices", [{}])[0] + logprobs_data = choice.get("logprobs", {}) + top_logprobs = logprobs_data.get("top_logprobs", []) + if not top_logprobs: + return [], [] + + seq_token_ids: List[List[int]] = [] + seq_logprobs: List[List[float]] = [] + for pos_logprobs in top_logprobs: + if pos_logprobs is None: + seq_token_ids.append([]) 
+ seq_logprobs.append([]) + elif isinstance(pos_logprobs, dict): + sorted_items = sorted( + pos_logprobs.items(), key=lambda x: x[1], reverse=True + )[:top_k] + pos_ids: List[int] = [] + pos_lps: List[float] = [] + for token_str, logprob in sorted_items: + token_ids = self.tokenizer.encode( + token_str, add_special_tokens=False + ) + if token_ids: + pos_ids.append(int(token_ids[0])) + pos_lps.append(float(logprob)) + seq_token_ids.append(pos_ids) + seq_logprobs.append(pos_lps) + else: + seq_token_ids.append([]) + seq_logprobs.append([]) + return seq_token_ids, seq_logprobs + except Exception as e: + self.logger.warning("Error parsing completion logprobs: %s", e) + return [], [] + + def _parse_chat_logprobs( + self, data: Dict, top_k: int + ) -> Tuple[List[List[int]], List[List[float]]]: + try: + choice = data.get("choices", [{}])[0] + logprobs_data = choice.get("logprobs", {}) + if not logprobs_data: + return [], [] + + content = logprobs_data.get("content", []) + seq_token_ids: List[List[int]] = [] + seq_logprobs: List[List[float]] = [] + for token_data in content: + top_logprobs = token_data.get("top_logprobs", []) + pos_ids: List[int] = [] + pos_lps: List[float] = [] + for item in top_logprobs[:top_k]: + token_str = item.get("token", "") + logprob = item.get("logprob", 0.0) + token_ids = self.tokenizer.encode( + token_str, add_special_tokens=False + ) + if token_ids: + pos_ids.append(int(token_ids[0])) + pos_lps.append(float(logprob)) + seq_token_ids.append(pos_ids) + seq_logprobs.append(pos_lps) + return seq_token_ids, seq_logprobs + except Exception as e: + self.logger.warning("Error parsing chat logprobs: %s", e) + return [], [] From 559d649a263b89ce620024fadf61db4a4f212be9 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Fri, 20 Feb 2026 01:22:50 -0500 Subject: [PATCH 12/23] proper fallback --- README.md | 2 +- atroposlib/envs/base.py | 7 ++- .../envs/server_handling/teacher_client.py | 56 ++++++++++++++++++- 3 files changed, 61 insertions(+), 4 
deletions(-) diff --git a/README.md b/README.md index ddb2bac7..ef90b729 100644 --- a/README.md +++ b/README.md @@ -312,7 +312,7 @@ Atropos handles tokenization in two places: - Recommendation: set `--openai.tokenizer_name` explicitly to match the student serving model. 2. **Teacher top-k parsing path** - - Teacher responses are parsed into token ids/logprobs in `BaseEnv.get_teacher_logprobs`. + - Teacher responses are fetched/parsed in `TeacherClient.get_teacher_logprobs` (called by `BaseEnv`). - The parser maps teacher token strings into ids using the environment tokenizer (`self.tokenizer`) and then aligns to student sequence length. Because distillation is token-position based, keeping tokenizer families compatible is strongly recommended, especially for cross-model distillation. diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index 74383075..b816ce66 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -990,7 +990,12 @@ class BaseEnv(ABC): if self.config.include_messages and group.get("messages") is None: group["messages"] = [ - self.tokenizer.decode(group["tokens"][i]) + [ + { + "role": "user", + "content": self.tokenizer.decode(group["tokens"][i]), + } + ] for i in range(len(group["tokens"])) ] diff --git a/atroposlib/envs/server_handling/teacher_client.py b/atroposlib/envs/server_handling/teacher_client.py index a929d7dc..af5897b4 100644 --- a/atroposlib/envs/server_handling/teacher_client.py +++ b/atroposlib/envs/server_handling/teacher_client.py @@ -1,5 +1,5 @@ import os -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import aiohttp @@ -15,6 +15,28 @@ class TeacherClient: self.config = config self.tokenizer = tokenizer self.logger = logger + if self.config.distillation_enabled: + self._validate_distillation_config() + + def _validate_distillation_config(self) -> None: + if not self.config.teacher_base_url: + raise ValueError("Distillation requires `teacher_base_url` to be 
set.") + if self.config.teacher_top_k <= 0: + raise ValueError( + f"Distillation requires `teacher_top_k > 0`, got {self.config.teacher_top_k}." + ) + student_model_name = getattr(self.config, "model_name", None) + if ( + self.config.teacher_model_name + and student_model_name + and self.config.teacher_model_name != student_model_name + ): + self.logger.warning( + "Cross-model distillation configured (teacher=%s, student=%s). " + "Token-level alignment quality depends on tokenizer compatibility.", + self.config.teacher_model_name, + student_model_name, + ) async def get_teacher_logprobs( self, @@ -109,7 +131,7 @@ class TeacherClient: pass if messages_list and i < len(messages_list): - messages = list(messages_list[i]) + messages = self._normalize_messages(messages_list[i], full_text) if self.config.teacher_system_prompt: messages = [ { @@ -176,6 +198,36 @@ class TeacherClient: self.logger.error("Error fetching teacher logprobs: %s", e) return [], [] + def _normalize_messages( + self, raw_messages: Any, fallback_text: str + ) -> List[Dict[str, str]]: + """ + Normalize environment message payloads for chat/completions teacher fallback. + + Accepts already-structured message lists, plain strings, or unknown structures. 
+ """ + if isinstance(raw_messages, str): + return [{"role": "user", "content": raw_messages}] + + if isinstance(raw_messages, list): + normalized: List[Dict[str, str]] = [] + for msg in raw_messages: + if ( + isinstance(msg, dict) + and "role" in msg + and "content" in msg + and isinstance(msg["content"], str) + ): + normalized.append( + {"role": str(msg["role"]), "content": msg["content"]} + ) + elif isinstance(msg, str): + normalized.append({"role": "user", "content": msg}) + if normalized: + return normalized + + return [{"role": "user", "content": fallback_text}] + def _align_teacher_topk_to_tokens( self, seq_token_ids: List[List[int]], From e615eb1f506adba9fdd5104ccab64849b72e35d6 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Fri, 20 Feb 2026 02:16:49 -0500 Subject: [PATCH 13/23] assertions --- atroposlib/envs/base.py | 5 ++ .../envs/server_handling/openai_server.py | 23 ------ .../envs/server_handling/teacher_client.py | 81 +++++++++++++++++++ 3 files changed, 86 insertions(+), 23 deletions(-) diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index b816ce66..756a4cad 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -1034,6 +1034,11 @@ class BaseEnv(ABC): logger.error(f"[DISTILL] Failed to fetch teacher logprobs: {e}") import traceback logger.error(traceback.format_exc()) + self.teacher_client.assert_distill_arrays_aligned( + token_sequences=group["tokens"], + distill_token_ids=group.get("distill_token_ids"), + distill_logprobs=group.get("distill_logprobs"), + ) else: logger.debug( "[DISTILL] Skipped - enabled=%s, url=%s", diff --git a/atroposlib/envs/server_handling/openai_server.py b/atroposlib/envs/server_handling/openai_server.py index f84558c4..24582273 100644 --- a/atroposlib/envs/server_handling/openai_server.py +++ b/atroposlib/envs/server_handling/openai_server.py @@ -165,17 +165,12 @@ def resolve_openai_configs( """ from atroposlib.envs.server_handling.server_manager import ServerBaseline - 
print(f"[RESOLVE DEBUG] default_server_configs type = {type(default_server_configs)}") - print(f"[RESOLVE DEBUG] openai_config_dict = {openai_config_dict}") - openai_full_prefix = f"{OPENAI_NAMESPACE}{NAMESPACE_SEP}" openai_yaml_config = yaml_config.get(OPENAI_NAMESPACE, None) openai_cli_config = { k: v for k, v in cli_passed_flags.items() if k.startswith(openai_full_prefix) } - print(f"[RESOLVE DEBUG] openai_cli_config = {openai_cli_config}") - is_multi_server_yaml = ( isinstance(openai_yaml_config, list) and len(openai_yaml_config) >= 2 ) @@ -185,15 +180,6 @@ def resolve_openai_configs( and len(default_server_configs) >= 2 ) - print( - "[RESOLVE DEBUG] is_multi_server_yaml=" - f"{is_multi_server_yaml}, is_multi_server_default={is_multi_server_default}" - ) - print( - "[RESOLVE DEBUG] isinstance(default_server_configs, ServerBaseline) = " - f"{isinstance(default_server_configs, ServerBaseline)}" - ) - if (is_multi_server_yaml or is_multi_server_default) and openai_cli_config: raise FailedExecutionException( message=f"CLI overrides for OpenAI settings (--{openai_full_prefix}*) are not supported " @@ -203,7 +189,6 @@ def resolve_openai_configs( ) if is_multi_server_yaml: - print("[RESOLVE DEBUG] Taking multi-server YAML path") logger.info( f"Using multi-server configuration defined in YAML under '{OPENAI_NAMESPACE}'." 
) @@ -215,7 +200,6 @@ def resolve_openai_configs( ) from e elif isinstance(default_server_configs, APIServerConfig): # Check APIServerConfig BEFORE ServerBaseline since APIServerConfig inherits from ServerBaseline - print("[RESOLVE DEBUG] Taking APIServerConfig merged path") logger.info("Using single OpenAI server configuration based on merged settings (default/YAML/CLI).") try: final_openai_config = APIServerConfig(**openai_config_dict) @@ -227,15 +211,12 @@ def resolve_openai_configs( server_configs = final_openai_config elif isinstance(default_server_configs, ServerBaseline): # Pure ServerBaseline (not APIServerConfig) - no CLI overrides possible - print("[RESOLVE DEBUG] Taking ServerBaseline path") logger.info("Using ServerBaseline configuration.") server_configs = default_server_configs elif is_multi_server_default: - print("[RESOLVE DEBUG] Taking multi-server default path") logger.info("Using default multi-server configuration (length >= 2).") server_configs = default_server_configs else: - print("[RESOLVE DEBUG] Taking single server merged path (fallback)") logger.info( "Using single OpenAI server configuration based on merged settings (default/YAML/CLI)." 
) @@ -247,12 +228,9 @@ def resolve_openai_configs( f"Merged Dict: {openai_config_dict}" ) from e - print(f"[RESOLVE DEBUG] final_openai_config = {final_openai_config}") if isinstance(default_server_configs, APIServerConfig): - print("[RESOLVE DEBUG] Returning final_openai_config directly") server_configs = final_openai_config elif isinstance(default_server_configs, list): - print("[RESOLVE DEBUG] Returning [final_openai_config]") server_configs = [final_openai_config] else: logger.warning( @@ -261,5 +239,4 @@ def resolve_openai_configs( ) server_configs = [final_openai_config] - print(f"[RESOLVE DEBUG] Returning server_configs = {server_configs}") return server_configs diff --git a/atroposlib/envs/server_handling/teacher_client.py b/atroposlib/envs/server_handling/teacher_client.py index af5897b4..10dfc1ca 100644 --- a/atroposlib/envs/server_handling/teacher_client.py +++ b/atroposlib/envs/server_handling/teacher_client.py @@ -177,6 +177,9 @@ class TeacherClient: target_token_len=len(tokens), prefix_token_len=0, ) + aligned_ids, aligned_lps = self._normalize_aligned_rows( + aligned_ids, aligned_lps, top_k + ) else: aligned_ids = [[] for _ in range(len(tokens))] aligned_lps = [[] for _ in range(len(tokens))] @@ -198,6 +201,84 @@ class TeacherClient: self.logger.error("Error fetching teacher logprobs: %s", e) return [], [] + def _normalize_aligned_rows( + self, + seq_token_ids: List[List[int]], + seq_logprobs: List[List[float]], + top_k: int, + ) -> Tuple[List[List[int]], List[List[float]]]: + """ + Enforce per-position alignment invariants: + - same number of positions in ids and logprobs + - same number of top-k entries per position + - cap each position to <= top_k + """ + normalized_ids: List[List[int]] = [] + normalized_lps: List[List[float]] = [] + n_positions = max(len(seq_token_ids), len(seq_logprobs)) + for pos in range(n_positions): + ids = ( + seq_token_ids[pos] + if pos < len(seq_token_ids) and isinstance(seq_token_ids[pos], list) + else [] + ) + lps = ( 
+ seq_logprobs[pos] + if pos < len(seq_logprobs) and isinstance(seq_logprobs[pos], list) + else [] + ) + n = min(len(ids), len(lps), top_k) + normalized_ids.append([int(x) for x in ids[:n]]) + normalized_lps.append([float(x) for x in lps[:n]]) + return normalized_ids, normalized_lps + + def assert_distill_arrays_aligned( + self, + token_sequences: List[List[int]], + distill_token_ids: Optional[List[List[List[int]]]], + distill_logprobs: Optional[List[List[List[float]]]], + ) -> None: + """ + Strict OPD invariant checks: + - both arrays exist + - one sequence row per token sequence + - one position row per token position + - ids/logprobs top-k row lengths match at each position + """ + if distill_token_ids is None or distill_logprobs is None: + raise AssertionError( + "[DISTILL] distill_token_ids/distill_logprobs must both be present." + ) + + if len(distill_token_ids) != len(token_sequences) or len(distill_logprobs) != len( + token_sequences + ): + raise AssertionError( + "[DISTILL] sequence count mismatch: " + f"tokens={len(token_sequences)} ids={len(distill_token_ids)} " + f"lps={len(distill_logprobs)}" + ) + + for seq_idx, tokens in enumerate(token_sequences): + expected_positions = len(tokens) + seq_ids = distill_token_ids[seq_idx] + seq_lps = distill_logprobs[seq_idx] + + if len(seq_ids) != expected_positions or len(seq_lps) != expected_positions: + raise AssertionError( + "[DISTILL] position count mismatch at seq " + f"{seq_idx}: tokens={expected_positions} ids={len(seq_ids)} " + f"lps={len(seq_lps)}" + ) + + for pos_idx in range(expected_positions): + if len(seq_ids[pos_idx]) != len(seq_lps[pos_idx]): + raise AssertionError( + "[DISTILL] top-k row mismatch at " + f"seq={seq_idx}, pos={pos_idx}: " + f"ids={len(seq_ids[pos_idx])}, lps={len(seq_lps[pos_idx])}" + ) + def _normalize_messages( self, raw_messages: Any, fallback_text: str ) -> List[Dict[str, str]]: From 55f7cbd091fedc33a6fa11f56131a2096678c30c Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit 
Date: Fri, 20 Feb 2026 03:14:05 -0500 Subject: [PATCH 14/23] dynamic system prompts --- README.md | 56 ++++++ atroposlib/envs/base.py | 19 ++ .../envs/server_handling/teacher_client.py | 177 ++++++++++++++++-- 3 files changed, 240 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index ef90b729..1ea538b0 100644 --- a/README.md +++ b/README.md @@ -289,6 +289,62 @@ Distillation is configured in `BaseEnvConfig` and available via CLI under `--env - Optional steering controls: - `--env.teacher_prefix_text "..."` - `--env.teacher_system_prompt "..."` + - `--env.teacher_prompt_template "Question: {question}\nAnswer: {answer}\n{episodes}"` + +Template-first prompting is the recommended dynamic interface. + +#### Dynamic system prompts (step by step) + +1. **Define one reusable template at config time** + - Put stable policy/rules in `--env.teacher_prompt_template`. + - Use variable placeholders for runtime values, e.g. `{question}`, `{answer}`, `{episodes}`. + - Keep student trajectory out of the template unless you intentionally want duplication. + +2. **Pass runtime variables from the environment** + - At scoring time, attach variables to: + - `group_overrides.teacher_prompt_context` for one context shared by the group, or + - `overrides[i].teacher_prompt_context` for per-sequence customization. + - Alias key `teacher_prompt_variables` is also accepted. + +3. **(Optional) swap template per turn** + - You can set `group_overrides.teacher_prompt_template` or `overrides[i].teacher_prompt_template` + to change template structure on specific turns/samples. + +4. **Understand precedence** + - Per-sequence (`overrides[i]`) > group-level (`group_overrides`) > env config defaults. + - This lets you define a strong default template while overriding only special cases. + +5. **Know what the teacher actually sees** + - Teacher prompt is built as: rendered steering prefix + current student sequence text. 
+ - Distillation still aligns to student token positions after prefix trimming. + +#### Example template + +```text +You are a math teacher supervising a solution process. + +Hidden reference answer: +{answer} + +Rules: +1) Do not reveal the hidden reference directly. +2) Re-derive from first principles. +3) Give the final answer only after derivation in \boxed{{...}}. + +Question: +{question} +``` + +#### Example runtime injection (inside env scoring) + +```python +scores["group_overrides"] = { + "teacher_prompt_context": { + "question": rollout_group_data[0]["messages"][1]["content"], + "answer": rollout_group_data[0]["gold_answer"], + } +} +``` ### Self-distillation vs cross-model distillation diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index 756a4cad..f49dc3d5 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -248,6 +248,13 @@ class BaseEnvConfig(BaseModel): "this is converted to a textual prefix. For chat fallback, this is injected " "as a leading system message.", ) + teacher_prompt_template: Optional[str] = Field( + default=None, + description="Optional template-first teacher prompt renderer. " + "Uses Python format-style variables from runtime context/overrides " + "(e.g., {question}, {answer}, {episodes}). 
If set, this is preferred over " + "mode-specific prompt building.", + ) class BaseEnv(ABC): @@ -360,11 +367,15 @@ class BaseEnv(ABC): self, token_sequences: List[List[int]], messages_list: Optional[List[List[Dict]]] = None, + seq_overrides: Optional[List[Dict[str, Any]]] = None, + group_overrides: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None, ) -> Tuple[List[List[List[int]]], List[List[List[float]]]]: return await self.teacher_client.get_teacher_logprobs( token_sequences=token_sequences, messages_list=messages_list, + seq_overrides=seq_overrides, + group_overrides=group_overrides, top_k=top_k, ) @@ -1012,6 +1023,12 @@ class BaseEnv(ABC): if self.config.distillation_enabled and self.config.teacher_base_url: logger.info(f"[DISTILL] Fetching teacher logprobs for {len(valid_groups)} groups") for group in valid_groups: + seq_overrides = group.get("overrides") or [] + group_overrides = ( + group.get("group_overrides") + if isinstance(group.get("group_overrides"), dict) + else {} + ) has_new_format = ( group.get("distill_token_ids") is not None and group.get("distill_logprobs") is not None @@ -1021,6 +1038,8 @@ class BaseEnv(ABC): teacher_token_ids, teacher_logprobs = await self.get_teacher_logprobs( token_sequences=group["tokens"], messages_list=group.get("messages"), + seq_overrides=seq_overrides, + group_overrides=group_overrides, ) if teacher_token_ids and teacher_logprobs: group["distill_token_ids"] = teacher_token_ids diff --git a/atroposlib/envs/server_handling/teacher_client.py b/atroposlib/envs/server_handling/teacher_client.py index 10dfc1ca..c79a6b83 100644 --- a/atroposlib/envs/server_handling/teacher_client.py +++ b/atroposlib/envs/server_handling/teacher_client.py @@ -1,3 +1,4 @@ +import json import os from typing import Any, Dict, List, Optional, Tuple @@ -42,6 +43,8 @@ class TeacherClient: self, token_sequences: List[List[int]], messages_list: Optional[List[List[Dict]]] = None, + seq_overrides: Optional[List[Dict[str, Any]]] = None, + 
group_overrides: Optional[Dict[str, Any]] = None, top_k: Optional[int] = None, ) -> Tuple[List[List[List[int]]], List[List[List[float]]]]: self.logger.info( @@ -78,14 +81,24 @@ class TeacherClient: len(tokens), ) base_text = self.tokenizer.decode(tokens, skip_special_tokens=False) - steering_prefix = "" - if self.config.teacher_system_prompt: - steering_prefix += ( - "System instruction:\n" - f"{self.config.teacher_system_prompt.strip()}\n\n" - ) - if self.config.teacher_prefix_text: - steering_prefix += self.config.teacher_prefix_text + ( + seq_system_prompt, + seq_prefix_text, + seq_prompt_mode, + seq_prompt_context, + seq_prompt_template, + ) = self._resolve_prompt_overrides_for_sequence( + seq_idx=i, + seq_overrides=seq_overrides, + group_overrides=group_overrides, + ) + steering_prefix = self._build_teacher_steering_prefix( + system_prompt=seq_system_prompt, + prefix_text=seq_prefix_text, + mode=seq_prompt_mode, + context=seq_prompt_context, + template=seq_prompt_template, + ) full_text = steering_prefix + base_text prefix_token_len = ( len( @@ -124,6 +137,9 @@ class TeacherClient: target_token_len=len(tokens), prefix_token_len=prefix_token_len, ) + aligned_ids, aligned_lps = self._normalize_aligned_rows( + aligned_ids, aligned_lps, top_k + ) token_id_results.append(aligned_ids) logprob_results.append(aligned_lps) continue @@ -132,20 +148,20 @@ class TeacherClient: if messages_list and i < len(messages_list): messages = self._normalize_messages(messages_list[i], full_text) - if self.config.teacher_system_prompt: + if seq_system_prompt: messages = [ { "role": "system", - "content": self.config.teacher_system_prompt, + "content": seq_system_prompt, } ] + messages else: messages = [] - if self.config.teacher_system_prompt: + if seq_system_prompt: messages.append( { "role": "system", - "content": self.config.teacher_system_prompt, + "content": seq_system_prompt, } ) messages.append({"role": "user", "content": full_text}) @@ -201,6 +217,138 @@ class TeacherClient: 
self.logger.error("Error fetching teacher logprobs: %s", e) return [], [] + def _resolve_prompt_overrides_for_sequence( + self, + seq_idx: int, + seq_overrides: Optional[List[Dict[str, Any]]], + group_overrides: Optional[Dict[str, Any]], + ) -> Tuple[Optional[str], Optional[str], str, Dict[str, Any], Optional[str]]: + group_overrides = group_overrides or {} + seq_override: Dict[str, Any] = {} + if ( + seq_overrides is not None + and seq_idx < len(seq_overrides) + and isinstance(seq_overrides[seq_idx], dict) + ): + seq_override = seq_overrides[seq_idx] + + seq_system_prompt = seq_override.get( + "teacher_system_prompt", + group_overrides.get("teacher_system_prompt", self.config.teacher_system_prompt), + ) + seq_prefix_text = seq_override.get( + "teacher_prefix_text", + group_overrides.get("teacher_prefix_text", self.config.teacher_prefix_text), + ) + seq_prompt_mode = seq_override.get( + "teacher_prompt_mode", + group_overrides.get("teacher_prompt_mode", "default"), + ) + # `teacher_prompt_variables` is accepted as an alias for template-style usage. 
+ seq_prompt_context = seq_override.get("teacher_prompt_context") + if seq_prompt_context is None: + seq_prompt_context = seq_override.get("teacher_prompt_variables") + if seq_prompt_context is None: + seq_prompt_context = group_overrides.get("teacher_prompt_context") + if seq_prompt_context is None: + seq_prompt_context = group_overrides.get("teacher_prompt_variables") + if seq_prompt_context is None: + seq_prompt_context = {} + if not isinstance(seq_prompt_context, dict): + seq_prompt_context = {} + seq_prompt_template = seq_override.get( + "teacher_prompt_template", + group_overrides.get( + "teacher_prompt_template", self.config.teacher_prompt_template + ), + ) + return ( + seq_system_prompt, + seq_prefix_text, + seq_prompt_mode, + seq_prompt_context, + seq_prompt_template, + ) + + def _build_teacher_steering_prefix( + self, + system_prompt: Optional[str], + prefix_text: Optional[str], + mode: str, + context: Dict[str, Any], + template: Optional[str], + ) -> str: + base_parts: List[str] = [] + if system_prompt: + base_parts.append(f"System instruction:\n{system_prompt.strip()}\n") + if prefix_text: + base_parts.append(str(prefix_text)) + base = "\n".join(x for x in base_parts if x).strip() + + normalized_mode = (mode or "default").strip().lower() + ctx = context or {} + + # Template-first path (recommended): render once with runtime variables. 
+ if template: + template_vars = self._prepare_template_vars( + context=ctx, system_prompt=system_prompt, prefix_text=prefix_text + ) + try: + rendered = template.format_map(_SafeFormatDict(template_vars)) + except Exception: + rendered = template + return f"{rendered}\n\n" if rendered else "" + + if normalized_mode == "answer_context": + answer = ctx.get("answer") + if answer: + return ( + f"{base}\n\nReference answer/context:\n{answer}\n\n" + if base + else f"Reference answer/context:\n{answer}\n\n" + ) + return f"{base}\n\n" if base else "" + + if normalized_mode == "history_context": + episodes = ctx.get("episodes") + if isinstance(episodes, list) and episodes: + episode_lines = [f"Episode {idx + 1}: {ep}" for idx, ep in enumerate(episodes)] + history_block = "\n".join(episode_lines) + return ( + f"{base}\n\nPrevious episodes:\n{history_block}\n\n" + if base + else f"Previous episodes:\n{history_block}\n\n" + ) + return f"{base}\n\n" if base else "" + + return f"{base}\n\n" if base else "" + + def _prepare_template_vars( + self, + context: Dict[str, Any], + system_prompt: Optional[str], + prefix_text: Optional[str], + ) -> Dict[str, Any]: + """ + Build template variables with convenience aliases for common dynamic fields. 
+ """ + template_vars = dict(context) + template_vars.setdefault("system_prompt", system_prompt or "") + template_vars.setdefault("prefix_text", prefix_text or "") + template_vars.setdefault("answer", context.get("answer", "")) + template_vars.setdefault("question", context.get("question", "")) + episodes = context.get("episodes") + if isinstance(episodes, list): + template_vars.setdefault( + "episodes", + "\n".join(f"Episode {idx + 1}: {ep}" for idx, ep in enumerate(episodes)), + ) + template_vars.setdefault("episodes_json", json.dumps(episodes, ensure_ascii=True)) + else: + template_vars.setdefault("episodes", "") + template_vars.setdefault("episodes_json", "[]") + return template_vars + def _normalize_aligned_rows( self, seq_token_ids: List[List[int]], @@ -404,3 +552,8 @@ class TeacherClient: except Exception as e: self.logger.warning("Error parsing chat logprobs: %s", e) return [], [] + + +class _SafeFormatDict(dict): + def __missing__(self, key): + return "{" + key + "}" From 63007d1209e9cf6c4e0dc208f5c6ad1190c1180a Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Fri, 20 Feb 2026 03:16:27 -0500 Subject: [PATCH 15/23] dynamic system prompts --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 1ea538b0..b0ca8d15 100644 --- a/README.md +++ b/README.md @@ -298,7 +298,7 @@ Template-first prompting is the recommended dynamic interface. 1. **Define one reusable template at config time** - Put stable policy/rules in `--env.teacher_prompt_template`. - Use variable placeholders for runtime values, e.g. `{question}`, `{answer}`, `{episodes}`. - - Keep student trajectory out of the template unless you intentionally want duplication. + 2. 
**Pass runtime variables from the environment** - At scoring time, attach variables to: From fc248dd65bfee05ea918f50d51a118762cddb734 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Fri, 20 Feb 2026 12:01:50 -0500 Subject: [PATCH 16/23] clean --- .../envs/server_handling/vllm_server.py | 25 ------------------- 1 file changed, 25 deletions(-) diff --git a/atroposlib/envs/server_handling/vllm_server.py b/atroposlib/envs/server_handling/vllm_server.py index d3e7d2ea..96242754 100644 --- a/atroposlib/envs/server_handling/vllm_server.py +++ b/atroposlib/envs/server_handling/vllm_server.py @@ -2,7 +2,6 @@ # see example_trainer/vllm_api_server.py for an example import asyncio -import os import warnings import aiohttp @@ -190,30 +189,6 @@ class VLLMServer(APIServer): # Prepare request for VLLM native API request_data = {"prompt": {"prompt_token_ids": prompt_tokens}, "logprobs": 0} request_data.update(kwargs) - debug_requests = os.getenv("ATROPOS_DEBUG_REQUESTS", "0") == "1" - if debug_requests: - base = self.config.base_url.replace("/v1", "") - prompt_preview = self.tokenizer.decode(prompt_tokens[:256]).replace("\n", "\\n") - print( - f"[ATROPOS_REQ_DEBUG] vllm_generate_url={base}/generate " - f"prompt_token_len={len(prompt_tokens)}", - flush=True, - ) - print( - f"[ATROPOS_REQ_DEBUG] request_meta=" - f"{{'n': {request_data.get('n')}, 'max_tokens': {request_data.get('max_tokens')}, " - f"'temperature': {request_data.get('temperature')}, 'top_p': {request_data.get('top_p')}}}", - flush=True, - ) - print( - f"[ATROPOS_REQ_DEBUG] prompt_preview={prompt_preview!r}", - flush=True, - ) - print( - f"[ATROPOS_REQ_DEBUG] curl_base=curl -s -X POST {base}/generate " - '-H "Content-Type: application/json" -d \'\'', - flush=True, - ) # Make async request to VLLM /generate endpoint async with aiohttp.ClientSession() as session: From e5297148f92d59d085781b4241279a27b5b6d4d4 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Fri, 20 Feb 2026 14:50:43 -0500 Subject: [PATCH 
17/23] dynamic system prompt fixed --- atroposlib/envs/server_handling/teacher_client.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/atroposlib/envs/server_handling/teacher_client.py b/atroposlib/envs/server_handling/teacher_client.py index c79a6b83..f9f60fb0 100644 --- a/atroposlib/envs/server_handling/teacher_client.py +++ b/atroposlib/envs/server_handling/teacher_client.py @@ -278,6 +278,9 @@ class TeacherClient: context: Dict[str, Any], template: Optional[str], ) -> str: + system_prompt = self._normalize_multiline_text(system_prompt) + prefix_text = self._normalize_multiline_text(prefix_text) + template = self._normalize_multiline_text(template) base_parts: List[str] = [] if system_prompt: base_parts.append(f"System instruction:\n{system_prompt.strip()}\n") @@ -323,6 +326,15 @@ class TeacherClient: return f"{base}\n\n" if base else "" + def _normalize_multiline_text(self, value: Optional[str]) -> Optional[str]: + """ + Normalize common escaped newlines from CLI/YAML strings. + Keep other backslash sequences (e.g., \\boxed) intact. 
+ """ + if value is None: + return None + return value.replace("\\r\\n", "\n").replace("\\n", "\n").replace("\\t", "\t") + def _prepare_template_vars( self, context: Dict[str, Any], From e8d0e748774ef841fdf6c5dbb9b5f9c47f0731ed Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Tue, 24 Feb 2026 12:16:00 -0500 Subject: [PATCH 18/23] gsm8k cleanup --- environments/gsm8k_server.py | 63 ++---------------------------------- 1 file changed, 3 insertions(+), 60 deletions(-) diff --git a/environments/gsm8k_server.py b/environments/gsm8k_server.py index 0fa53431..59db58f6 100644 --- a/environments/gsm8k_server.py +++ b/environments/gsm8k_server.py @@ -235,12 +235,6 @@ class GSM8kEnv(BaseEnv): gold_answer = ( "\\boxed{" + item["answer"].split("#")[-1].strip().replace(",", "") + "}" ) - question_preview = item["question"].replace("\n", " ")[:120] - print( - f"[GSM8K_DEBUG] collect_start group_size={self.config.group_size} " - f"q={question_preview!r}", - flush=True, - ) # Important: do not force env tokenizer into ManagedServer for rollout. # Let ManagedServer use the server's tokenizer to keep prompt token IDs aligned. 
@@ -255,24 +249,10 @@ class GSM8kEnv(BaseEnv): state = managed.get_state() nodes = state["nodes"] - print( - f"[GSM8K_DEBUG] completion_batch_received choices={len(chat_completions.choices)} " - f"nodes={len(nodes)}", - flush=True, - ) to_score = list() to_backlog = list() for i, chat_completion in enumerate(chat_completions.choices): - response_text = chat_completion.message.content or "" - response_preview = response_text.replace("\n", " ")[:220] - valid_mask_count = sum(1 for m in nodes[i].masked_tokens if m != -100) - print( - f"[GSM8K_DEBUG] response_received idx={i} finish={chat_completion.finish_reason} " - f"tokens={len(nodes[i].tokens)} valid_masked={valid_mask_count} " - f"text={response_preview!r}", - flush=True, - ) messages = ( {"role": "system", "content": system_prompt}, user_message, @@ -289,11 +269,6 @@ class GSM8kEnv(BaseEnv): } ) to_postprocess = await self.score(to_score) - accepted = 0 if to_postprocess is None else len(to_postprocess.get("tokens", [])) - print( - f"[GSM8K_DEBUG] collect_done accepted={accepted} submitted={len(to_score)}", - flush=True, - ) return to_postprocess, to_backlog async def score( @@ -309,11 +284,6 @@ class GSM8kEnv(BaseEnv): extraction_mode="first_match", extraction_config=[LatexExtractionConfig()], ) - print( - f"[GSM8K_DEBUG] score_start candidates={len(rollout_group_data)} " - f"gold_parsed_len={len(gold_parsed)}", - flush=True, - ) if len(gold_parsed) != 0: # We require the answer to be provided in correct latex (no malformed operators) random.shuffle(rollout_group_data) @@ -347,18 +317,7 @@ class GSM8kEnv(BaseEnv): # remove obviously bad examples valid_mask_count = len([1 for i in masks if i != -100]) - print( - f"[GSM8K_DEBUG] score_candidate idx={idx} parsed_len={len(answer_parsed)} " - f"reward={bool(reward)} valid_masked={valid_mask_count} " - f"tokens={len(tokens)}", - flush=True, - ) if valid_mask_count < 10: - print( - f"[GSM8K_DEBUG] drop_candidate idx={idx} reason=valid_masked_lt_10 " - 
f"value={valid_mask_count}", - flush=True, - ) continue scores["tokens"].append(tokens) scores["masks"].append(masks) @@ -372,10 +331,6 @@ class GSM8kEnv(BaseEnv): self.percent_correct_buffer.append(max(score, 0)) if len(scores["scores"]) == 0: - print( - "[GSM8K_DEBUG] drop_group reason=no_valid_candidates_after_filtering", - flush=True, - ) return None # check if all the same @@ -385,10 +340,6 @@ class GSM8kEnv(BaseEnv): token_lengths = [len(token) for token in scores["tokens"]] if max(token_lengths) == 0: # What? But don't want to crash a run so just in case... - print( - "[GSM8K_DEBUG] drop_group reason=zero_token_length_after_penalty_branch", - flush=True, - ) return None # Get max allowed token length from config @@ -411,21 +362,13 @@ class GSM8kEnv(BaseEnv): percentage_of_range = min(percentage_of_range, 1.0) # Apply linear penalty scaling from 1.0 down to 0.0 scores["scores"].append(1.0 - percentage_of_range) - if all([scores["scores"][0] == score for score in scores["scores"]]): - print( - f"[GSM8K_DEBUG] drop_group reason=all_scores_identical scores={scores['scores']}", - flush=True, - ) + if self.config.ensure_scores_are_not_same and all( + [scores["scores"][0] == score for score in scores["scores"]] + ): return None # If all the same, we return None - print( - f"[GSM8K_DEBUG] score_done accepted={len(scores['scores'])} " - f"scores={scores['scores']}", - flush=True, - ) return scores else: # If the gold solution is not parseable, we return None - print("[GSM8K_DEBUG] drop_group reason=gold_unparseable", flush=True) return None async def get_next_item(self) -> GSM8kRow: From f343b24a6a64008df0d6ff4676636bd3830da9de Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Fri, 27 Feb 2026 11:14:42 -0500 Subject: [PATCH 19/23] narrow down scope --- README.md | 185 +----- atroposlib/envs/base.py | 135 ----- .../envs/server_handling/teacher_client.py | 571 ------------------ 3 files changed, 22 insertions(+), 869 deletions(-) delete mode 100644 
atroposlib/envs/server_handling/teacher_client.py diff --git a/README.md b/README.md index b0ca8d15..3b533a9b 100644 --- a/README.md +++ b/README.md @@ -256,188 +256,47 @@ Atropos repo contains an example trainer that should primarily be used as a refe To use the example trainer, see this page: [training example guide](example_trainer/README.md) -## On-Policy Distillation (Environment + API Flow) +## On-Policy Distillation (API + ScoredDataGroup Contract) -Atropos supports on-policy distillation by fetching teacher top-k token distributions for the same trajectories that are used for RL, then attaching that teacher data to the batches consumed by your trainer. +Atropos now supports OPD at the transport layer by carrying distillation arrays +through `ScoredDataGroup` and the API queue/batch endpoints. -### How the flow works +### Scope of this change -1. **Student rollouts are generated by the environment** - - Example: `environments/gsm8k_server.py` samples `group_size` completions from your student inference server. -2. **Environment scores and validates the group** - - Normal Atropos filtering still applies (group structure, optional score-equality checks, max token checks, etc.). -3. **Teacher logprobs are fetched in `BaseEnv.handle_send_to_api`** - - If `distillation_enabled=true` and `teacher_base_url` is set, Atropos calls the teacher endpoint and builds: - - `distill_token_ids` with shape `[sequence][position][top_k]` - - `distill_logprobs` with shape `[sequence][position][top_k]` -4. **Distillation arrays are attached to each scored group** - - Added as `group["distill_token_ids"]` and `group["distill_logprobs"]`. -5. **Atropos API stores and serves these fields unchanged** - - `/scored_data` and `/batch` include the distillation arrays. -6. **Trainer consumes both RL and distillation signals** - - Example trainer computes GRPO + distillation loss from the same batch. +- No teacher fetching/orchestration in `BaseEnv`. 
+- Environments or external pipelines are responsible for populating distillation arrays. +- API stores and returns those arrays unchanged. -### Configuration knobs in environments +### Distillation payload fields -Distillation is configured in `BaseEnvConfig` and available via CLI under `--env.*`: +Each scored group may include: -- `--env.distillation_enabled true` -- `--env.teacher_base_url http://localhost:8003/v1` -- `--env.teacher_model_name ` -- `--env.teacher_api_key ` (or `TEACHER_API_KEY`) -- `--env.teacher_top_k 20` -- Optional steering controls: - - `--env.teacher_prefix_text "..."` - - `--env.teacher_system_prompt "..."` - - `--env.teacher_prompt_template "Question: {question}\nAnswer: {answer}\n{episodes}"` +- `distill_token_ids`: shape `[sequence][position][top_k]` +- `distill_logprobs`: shape `[sequence][position][top_k]` -Template-first prompting is the recommended dynamic interface. +These fields are optional, and when present are forwarded from: -#### Dynamic system prompts (step by step) +- environment -> `/scored_data` or `/scored_data_list` +- API queue -> `/batch` -> trainer -1. **Define one reusable template at config time** - - Put stable policy/rules in `--env.teacher_prompt_template`. - - Use variable placeholders for runtime values, e.g. `{question}`, `{answer}`, `{episodes}`. - - -2. **Pass runtime variables from the environment** - - At scoring time, attach variables to: - - `group_overrides.teacher_prompt_context` for one context shared by the group, or - - `overrides[i].teacher_prompt_context` for per-sequence customization. - - Alias key `teacher_prompt_variables` is also accepted. - -3. **(Optional) swap template per turn** - - You can set `group_overrides.teacher_prompt_template` or `overrides[i].teacher_prompt_template` - to change template structure on specific turns/samples. - -4. **Understand precedence** - - Per-sequence (`overrides[i]`) > group-level (`group_overrides`) > env config defaults. 
- - This lets you define a strong default template while overriding only special cases. - -5. **Know what the teacher actually sees** - - Teacher prompt is built as: rendered steering prefix + current student sequence text. - - Distillation still aligns to student token positions after prefix trimming. - -#### Example template - -```text -You are a math teacher supervising a solution process. - -Hidden reference answer: -{answer} - -Rules: -1) Do not reveal the hidden reference directly. -2) Re-derive from first principles. -3) Give the final answer only after derivation in \boxed{{...}}. - -Question: -{question} -``` - -#### Example runtime injection (inside env scoring) +### Minimal producer example (environment side) ```python -scores["group_overrides"] = { - "teacher_prompt_context": { - "question": rollout_group_data[0]["messages"][1]["content"], - "answer": rollout_group_data[0]["gold_answer"], - } -} +scores["distill_token_ids"] = distill_token_ids +scores["distill_logprobs"] = distill_logprobs ``` -### Self-distillation vs cross-model distillation +### Minimal consumer check (trainer/debug side) -Both setups are supported: - -- **Self-distillation (same model family for teacher and student)** - Point `teacher_base_url` to a server running the same model (or equivalent checkpoint family) as the student. This is the most stable setup for token-level alignment. - -- **Cross-model distillation (different teacher and student models)** - Also supported, but tokenization compatibility becomes more important. If token vocabularies/template behavior differ significantly, alignment quality may degrade. - -In practice, self-distillation is usually easiest to bring up first, then cross-model can be layered in once your pipeline is stable. - -### Tokenization and alignment details - -Atropos handles tokenization in two places: - -1. 
**Student rollout path (`server_type=vllm`)** - - The `/generate` request is built via the vLLM server handler and uses the server-side tokenizer configured by: - - `--openai.tokenizer_name` (or falls back to `--openai.model_name`) - - Recommendation: set `--openai.tokenizer_name` explicitly to match the student serving model. - -2. **Teacher top-k parsing path** - - Teacher responses are fetched/parsed in `TeacherClient.get_teacher_logprobs` (called by `BaseEnv`). - - The parser maps teacher token strings into ids using the environment tokenizer (`self.tokenizer`) and then aligns to student sequence length. - -Because distillation is token-position based, keeping tokenizer families compatible is strongly recommended, especially for cross-model distillation. - -### Minimal bring-up example - -Run each command in a separate terminal. - -1. **Start Atropos API** -```bash -run-api --port 8002 -``` - -2. **Start teacher server (OpenAI-compatible endpoint)** -```bash -python -m vllm.entrypoints.openai.api_server \ - --model "$TEACHER_MODEL" \ - --host 0.0.0.0 \ - --port 8003 -``` - -3. **Start student server for environments (`/generate` endpoint)** -```bash -python -m example_trainer.vllm_api_server \ - --model "$STUDENT_MODEL" \ - --port 9001 -``` - -4. **Start environment with distillation enabled** -```bash -python environments/gsm8k_server.py serve \ - --env.rollout_server_url "http://localhost:8002" \ - --env.distillation_enabled true \ - --env.teacher_base_url "http://localhost:8003/v1" \ - --env.teacher_model_name "$TEACHER_MODEL" \ - --env.teacher_top_k 20 \ - --openai.server_type vllm \ - --openai.base_url "http://localhost:9001/v1" \ - --openai.model_name "$STUDENT_MODEL" \ - --openai.tokenizer_name "$STUDENT_MODEL" -``` - -5. 
**Start trainer with distillation flags** -```bash -python -m example_trainer.grpo \ - --atropos-url "http://localhost:8002" \ - --distillation-enabled \ - --distillation-coef 0.1 \ - --distillation-loss-type kl \ - --distillation-temperature 1.0 -``` - -### Verification checklist - -- Environment logs show distillation fetch: - - `[DISTILL] Fetching teacher logprobs ...` - - `[DISTILL] Added teacher distill arrays ...` -- Teacher logs show completion/chat requests from the environment. -- API contains distill fields in latest example: ```bash curl -s http://localhost:8002/latest_example | jq '{has_ids:(.distill_token_ids!=null), has_lps:(.distill_logprobs!=null)}' ``` -- Trainer logs report distillation metrics (example trainer): - - `Distill: loss=...` -### Important notes +### Notes -- For `server_type=vllm`, the environment expects a server exposing `/generate` (the custom server in `example_trainer/vllm_api_server.py`), not only `/v1/chat/completions`. -- Prefer explicitly setting `--openai.tokenizer_name` to your student tokenizer to avoid prompt token-ID mismatch. +- The API does not validate cross-field semantics beyond schema typing. +- Trainers should validate alignment assumptions they require (sequence length, per-position top-k, etc.). +- Teacher-side architecture and prompt/rendering strategy are intentionally out of scope for this PR. --- diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index f49dc3d5..38b5965f 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -49,7 +49,6 @@ from .server_handling.server_manager import ( ServerManager, ServerManagerConfig, ) -from .server_handling.teacher_client import TeacherClient logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -212,51 +211,6 @@ class BaseEnvConfig(BaseModel): "no thinking prompt is injected. 
Use HERMES_REASONING_PROMPT from " "eval_helpers for the standard Hermes reasoning prompt.", ) - # On-policy distillation settings - distillation_enabled: bool = Field( - default=False, - description="Enable on-policy distillation. When True, automatically fetches teacher logprobs " - "after scoring and includes them in data sent to trainer.", - ) - teacher_base_url: Optional[str] = Field( - default=None, - description="Base URL of teacher model for distillation. Supports any OpenAI-compatible API " - "(vLLM, OpenAI, Together, etc.). Examples: 'http://localhost:8001/v1', 'https://api.openai.com/v1'", - ) - teacher_model_name: Optional[str] = Field( - default=None, - description="Model name for teacher API calls (e.g., 'gpt-4o', 'meta-llama/Llama-3-70b'). " - "If None, uses 'default' which works for single-model vLLM servers.", - ) - teacher_api_key: Optional[str] = Field( - default=None, - description="API key for teacher model. Can also be set via TEACHER_API_KEY env var.", - ) - teacher_top_k: int = Field( - default=20, - description="Number of top logprobs to fetch from teacher model per position.", - ) - teacher_prefix_text: Optional[str] = Field( - default=None, - description="Optional text prefix prepended to teacher scoring prompt. " - "Useful for behavior steering. Prefix token positions are trimmed out " - "before sending distillation arrays to the trainer, preserving alignment.", - ) - teacher_system_prompt: Optional[str] = Field( - default=None, - description="Optional teacher system prompt. For completion-style teacher APIs, " - "this is converted to a textual prefix. For chat fallback, this is injected " - "as a leading system message.", - ) - teacher_prompt_template: Optional[str] = Field( - default=None, - description="Optional template-first teacher prompt renderer. " - "Uses Python format-style variables from runtime context/overrides " - "(e.g., {question}, {answer}, {episodes}). 
If set, this is preferred over " - "mode-specific prompt building.", - ) - - class BaseEnv(ABC): name: Optional[str] = None env_config_cls: BaseEnvConfig = BaseEnvConfig @@ -305,9 +259,6 @@ class BaseEnv(ABC): self.curr_step = 0 self.max_token_len = -1 self.tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name) - self.teacher_client = TeacherClient( - config=self.config, tokenizer=self.tokenizer, logger=logger - ) self.completion_lengths = [] self.max_num_workers = config.max_num_workers if self.max_num_workers == -1: @@ -363,46 +314,6 @@ class BaseEnv(ABC): # Calculate derived batch size return int(self.config.batch_size * effective_fraction) - async def get_teacher_logprobs( - self, - token_sequences: List[List[int]], - messages_list: Optional[List[List[Dict]]] = None, - seq_overrides: Optional[List[Dict[str, Any]]] = None, - group_overrides: Optional[Dict[str, Any]] = None, - top_k: Optional[int] = None, - ) -> Tuple[List[List[List[int]]], List[List[List[float]]]]: - return await self.teacher_client.get_teacher_logprobs( - token_sequences=token_sequences, - messages_list=messages_list, - seq_overrides=seq_overrides, - group_overrides=group_overrides, - top_k=top_k, - ) - - def _align_teacher_topk_to_tokens( - self, - seq_token_ids: List[List[int]], - seq_logprobs: List[List[float]], - target_token_len: int, - prefix_token_len: int = 0, - ) -> Tuple[List[List[int]], List[List[float]]]: - return self.teacher_client._align_teacher_topk_to_tokens( - seq_token_ids=seq_token_ids, - seq_logprobs=seq_logprobs, - target_token_len=target_token_len, - prefix_token_len=prefix_token_len, - ) - - def _parse_completion_logprobs( - self, data: Dict, top_k: int - ) -> Tuple[List[List[int]], List[List[float]]]: - return self.teacher_client._parse_completion_logprobs(data=data, top_k=top_k) - - def _parse_chat_logprobs( - self, data: Dict, top_k: int - ) -> Tuple[List[List[int]], List[List[float]]]: - return self.teacher_client._parse_chat_logprobs(data=data, 
top_k=top_k) - @classmethod def config_init( cls, @@ -1019,52 +930,6 @@ class BaseEnv(ABC): valid_groups.append(group) if valid_groups and do_send_to_api: - # On-policy distillation: fetch teacher logprobs if enabled - if self.config.distillation_enabled and self.config.teacher_base_url: - logger.info(f"[DISTILL] Fetching teacher logprobs for {len(valid_groups)} groups") - for group in valid_groups: - seq_overrides = group.get("overrides") or [] - group_overrides = ( - group.get("group_overrides") - if isinstance(group.get("group_overrides"), dict) - else {} - ) - has_new_format = ( - group.get("distill_token_ids") is not None - and group.get("distill_logprobs") is not None - ) - if not has_new_format: - try: - teacher_token_ids, teacher_logprobs = await self.get_teacher_logprobs( - token_sequences=group["tokens"], - messages_list=group.get("messages"), - seq_overrides=seq_overrides, - group_overrides=group_overrides, - ) - if teacher_token_ids and teacher_logprobs: - group["distill_token_ids"] = teacher_token_ids - group["distill_logprobs"] = teacher_logprobs - logger.info( - f"[DISTILL] Added teacher distill arrays for {len(teacher_token_ids)} sequences" - ) - else: - logger.warning("[DISTILL] get_teacher_logprobs returned empty") - except Exception as e: - logger.error(f"[DISTILL] Failed to fetch teacher logprobs: {e}") - import traceback - logger.error(traceback.format_exc()) - self.teacher_client.assert_distill_arrays_aligned( - token_sequences=group["tokens"], - distill_token_ids=group.get("distill_token_ids"), - distill_logprobs=group.get("distill_logprobs"), - ) - else: - logger.debug( - "[DISTILL] Skipped - enabled=%s, url=%s", - self.config.distillation_enabled, - self.config.teacher_base_url, - ) - data_to_send_to_api: Union[ScoredDataGroup, List[ScoredDataGroup]] # send single or list of scored data groups if not original_was_list and len(valid_groups) == 1: diff --git a/atroposlib/envs/server_handling/teacher_client.py 
b/atroposlib/envs/server_handling/teacher_client.py deleted file mode 100644 index f9f60fb0..00000000 --- a/atroposlib/envs/server_handling/teacher_client.py +++ /dev/null @@ -1,571 +0,0 @@ -import json -import os -from typing import Any, Dict, List, Optional, Tuple - -import aiohttp - - -class TeacherClient: - """ - Transport/parsing client for teacher top-k logprobs. - - This keeps distillation HTTP and parsing logic out of BaseEnv. - """ - - def __init__(self, config, tokenizer, logger): - self.config = config - self.tokenizer = tokenizer - self.logger = logger - if self.config.distillation_enabled: - self._validate_distillation_config() - - def _validate_distillation_config(self) -> None: - if not self.config.teacher_base_url: - raise ValueError("Distillation requires `teacher_base_url` to be set.") - if self.config.teacher_top_k <= 0: - raise ValueError( - f"Distillation requires `teacher_top_k > 0`, got {self.config.teacher_top_k}." - ) - student_model_name = getattr(self.config, "model_name", None) - if ( - self.config.teacher_model_name - and student_model_name - and self.config.teacher_model_name != student_model_name - ): - self.logger.warning( - "Cross-model distillation configured (teacher=%s, student=%s). 
" - "Token-level alignment quality depends on tokenizer compatibility.", - self.config.teacher_model_name, - student_model_name, - ) - - async def get_teacher_logprobs( - self, - token_sequences: List[List[int]], - messages_list: Optional[List[List[Dict]]] = None, - seq_overrides: Optional[List[Dict[str, Any]]] = None, - group_overrides: Optional[Dict[str, Any]] = None, - top_k: Optional[int] = None, - ) -> Tuple[List[List[List[int]]], List[List[List[float]]]]: - self.logger.info( - "[TEACHER] get_teacher_logprobs called with %s sequences", - len(token_sequences), - ) - self.logger.info("[TEACHER] teacher_base_url=%s", self.config.teacher_base_url) - - if not self.config.teacher_base_url: - self.logger.warning("[TEACHER] No teacher_base_url configured, returning empty") - return [], [] - - if top_k is None: - top_k = self.config.teacher_top_k - - api_key = self.config.teacher_api_key or os.environ.get("TEACHER_API_KEY", "") - model_name = self.config.teacher_model_name or "default" - self.logger.info("[TEACHER] Using model=%s, top_k=%s", model_name, top_k) - - headers = {"Content-Type": "application/json"} - if api_key: - headers["Authorization"] = f"Bearer {api_key}" - - token_id_results: List[List[List[int]]] = [] - logprob_results: List[List[List[float]]] = [] - - try: - async with aiohttp.ClientSession() as session: - for i, tokens in enumerate(token_sequences): - self.logger.info( - "[TEACHER] Processing sequence %s/%s, %s tokens", - i + 1, - len(token_sequences), - len(tokens), - ) - base_text = self.tokenizer.decode(tokens, skip_special_tokens=False) - ( - seq_system_prompt, - seq_prefix_text, - seq_prompt_mode, - seq_prompt_context, - seq_prompt_template, - ) = self._resolve_prompt_overrides_for_sequence( - seq_idx=i, - seq_overrides=seq_overrides, - group_overrides=group_overrides, - ) - steering_prefix = self._build_teacher_steering_prefix( - system_prompt=seq_system_prompt, - prefix_text=seq_prefix_text, - mode=seq_prompt_mode, - 
context=seq_prompt_context, - template=seq_prompt_template, - ) - full_text = steering_prefix + base_text - prefix_token_len = ( - len( - self.tokenizer.encode( - steering_prefix, add_special_tokens=False - ) - ) - if steering_prefix - else 0 - ) - - request_data = { - "model": model_name, - "prompt": full_text, - "max_tokens": 1, - "temperature": 1.0, - "logprobs": top_k, - "echo": True, - } - try: - async with session.post( - f"{self.config.teacher_base_url}/completions", - json=request_data, - headers=headers, - timeout=aiohttp.ClientTimeout(total=120), - ) as response: - if response.status == 200: - data = await response.json() - seq_token_ids, seq_logprobs = self._parse_completion_logprobs( - data, top_k - ) - if seq_token_ids and seq_logprobs: - aligned_ids, aligned_lps = self._align_teacher_topk_to_tokens( - seq_token_ids, - seq_logprobs, - target_token_len=len(tokens), - prefix_token_len=prefix_token_len, - ) - aligned_ids, aligned_lps = self._normalize_aligned_rows( - aligned_ids, aligned_lps, top_k - ) - token_id_results.append(aligned_ids) - logprob_results.append(aligned_lps) - continue - except Exception: - pass - - if messages_list and i < len(messages_list): - messages = self._normalize_messages(messages_list[i], full_text) - if seq_system_prompt: - messages = [ - { - "role": "system", - "content": seq_system_prompt, - } - ] + messages - else: - messages = [] - if seq_system_prompt: - messages.append( - { - "role": "system", - "content": seq_system_prompt, - } - ) - messages.append({"role": "user", "content": full_text}) - - chat_request = { - "model": model_name, - "messages": messages, - "max_tokens": 1, - "temperature": 1.0, - "logprobs": True, - "top_logprobs": top_k, - } - try: - async with session.post( - f"{self.config.teacher_base_url}/chat/completions", - json=chat_request, - headers=headers, - timeout=aiohttp.ClientTimeout(total=120), - ) as response: - if response.status == 200: - data = await response.json() - seq_token_ids, seq_logprobs 
= self._parse_chat_logprobs( - data, top_k - ) - if seq_token_ids and len(seq_token_ids) >= len(tokens): - aligned_ids, aligned_lps = self._align_teacher_topk_to_tokens( - seq_token_ids, - seq_logprobs, - target_token_len=len(tokens), - prefix_token_len=0, - ) - aligned_ids, aligned_lps = self._normalize_aligned_rows( - aligned_ids, aligned_lps, top_k - ) - else: - aligned_ids = [[] for _ in range(len(tokens))] - aligned_lps = [[] for _ in range(len(tokens))] - token_id_results.append(aligned_ids) - logprob_results.append(aligned_lps) - else: - self.logger.warning( - "Teacher API returned %s", response.status - ) - token_id_results.append([[] for _ in range(len(tokens))]) - logprob_results.append([[] for _ in range(len(tokens))]) - except Exception as e: - self.logger.warning("Teacher chat request failed: %s", e) - token_id_results.append([[] for _ in range(len(tokens))]) - logprob_results.append([[] for _ in range(len(tokens))]) - - return token_id_results, logprob_results - except Exception as e: - self.logger.error("Error fetching teacher logprobs: %s", e) - return [], [] - - def _resolve_prompt_overrides_for_sequence( - self, - seq_idx: int, - seq_overrides: Optional[List[Dict[str, Any]]], - group_overrides: Optional[Dict[str, Any]], - ) -> Tuple[Optional[str], Optional[str], str, Dict[str, Any], Optional[str]]: - group_overrides = group_overrides or {} - seq_override: Dict[str, Any] = {} - if ( - seq_overrides is not None - and seq_idx < len(seq_overrides) - and isinstance(seq_overrides[seq_idx], dict) - ): - seq_override = seq_overrides[seq_idx] - - seq_system_prompt = seq_override.get( - "teacher_system_prompt", - group_overrides.get("teacher_system_prompt", self.config.teacher_system_prompt), - ) - seq_prefix_text = seq_override.get( - "teacher_prefix_text", - group_overrides.get("teacher_prefix_text", self.config.teacher_prefix_text), - ) - seq_prompt_mode = seq_override.get( - "teacher_prompt_mode", - group_overrides.get("teacher_prompt_mode", "default"), 
- ) - # `teacher_prompt_variables` is accepted as an alias for template-style usage. - seq_prompt_context = seq_override.get("teacher_prompt_context") - if seq_prompt_context is None: - seq_prompt_context = seq_override.get("teacher_prompt_variables") - if seq_prompt_context is None: - seq_prompt_context = group_overrides.get("teacher_prompt_context") - if seq_prompt_context is None: - seq_prompt_context = group_overrides.get("teacher_prompt_variables") - if seq_prompt_context is None: - seq_prompt_context = {} - if not isinstance(seq_prompt_context, dict): - seq_prompt_context = {} - seq_prompt_template = seq_override.get( - "teacher_prompt_template", - group_overrides.get( - "teacher_prompt_template", self.config.teacher_prompt_template - ), - ) - return ( - seq_system_prompt, - seq_prefix_text, - seq_prompt_mode, - seq_prompt_context, - seq_prompt_template, - ) - - def _build_teacher_steering_prefix( - self, - system_prompt: Optional[str], - prefix_text: Optional[str], - mode: str, - context: Dict[str, Any], - template: Optional[str], - ) -> str: - system_prompt = self._normalize_multiline_text(system_prompt) - prefix_text = self._normalize_multiline_text(prefix_text) - template = self._normalize_multiline_text(template) - base_parts: List[str] = [] - if system_prompt: - base_parts.append(f"System instruction:\n{system_prompt.strip()}\n") - if prefix_text: - base_parts.append(str(prefix_text)) - base = "\n".join(x for x in base_parts if x).strip() - - normalized_mode = (mode or "default").strip().lower() - ctx = context or {} - - # Template-first path (recommended): render once with runtime variables. 
- if template: - template_vars = self._prepare_template_vars( - context=ctx, system_prompt=system_prompt, prefix_text=prefix_text - ) - try: - rendered = template.format_map(_SafeFormatDict(template_vars)) - except Exception: - rendered = template - return f"{rendered}\n\n" if rendered else "" - - if normalized_mode == "answer_context": - answer = ctx.get("answer") - if answer: - return ( - f"{base}\n\nReference answer/context:\n{answer}\n\n" - if base - else f"Reference answer/context:\n{answer}\n\n" - ) - return f"{base}\n\n" if base else "" - - if normalized_mode == "history_context": - episodes = ctx.get("episodes") - if isinstance(episodes, list) and episodes: - episode_lines = [f"Episode {idx + 1}: {ep}" for idx, ep in enumerate(episodes)] - history_block = "\n".join(episode_lines) - return ( - f"{base}\n\nPrevious episodes:\n{history_block}\n\n" - if base - else f"Previous episodes:\n{history_block}\n\n" - ) - return f"{base}\n\n" if base else "" - - return f"{base}\n\n" if base else "" - - def _normalize_multiline_text(self, value: Optional[str]) -> Optional[str]: - """ - Normalize common escaped newlines from CLI/YAML strings. - Keep other backslash sequences (e.g., \\boxed) intact. - """ - if value is None: - return None - return value.replace("\\r\\n", "\n").replace("\\n", "\n").replace("\\t", "\t") - - def _prepare_template_vars( - self, - context: Dict[str, Any], - system_prompt: Optional[str], - prefix_text: Optional[str], - ) -> Dict[str, Any]: - """ - Build template variables with convenience aliases for common dynamic fields. 
- """ - template_vars = dict(context) - template_vars.setdefault("system_prompt", system_prompt or "") - template_vars.setdefault("prefix_text", prefix_text or "") - template_vars.setdefault("answer", context.get("answer", "")) - template_vars.setdefault("question", context.get("question", "")) - episodes = context.get("episodes") - if isinstance(episodes, list): - template_vars.setdefault( - "episodes", - "\n".join(f"Episode {idx + 1}: {ep}" for idx, ep in enumerate(episodes)), - ) - template_vars.setdefault("episodes_json", json.dumps(episodes, ensure_ascii=True)) - else: - template_vars.setdefault("episodes", "") - template_vars.setdefault("episodes_json", "[]") - return template_vars - - def _normalize_aligned_rows( - self, - seq_token_ids: List[List[int]], - seq_logprobs: List[List[float]], - top_k: int, - ) -> Tuple[List[List[int]], List[List[float]]]: - """ - Enforce per-position alignment invariants: - - same number of positions in ids and logprobs - - same number of top-k entries per position - - cap each position to <= top_k - """ - normalized_ids: List[List[int]] = [] - normalized_lps: List[List[float]] = [] - n_positions = max(len(seq_token_ids), len(seq_logprobs)) - for pos in range(n_positions): - ids = ( - seq_token_ids[pos] - if pos < len(seq_token_ids) and isinstance(seq_token_ids[pos], list) - else [] - ) - lps = ( - seq_logprobs[pos] - if pos < len(seq_logprobs) and isinstance(seq_logprobs[pos], list) - else [] - ) - n = min(len(ids), len(lps), top_k) - normalized_ids.append([int(x) for x in ids[:n]]) - normalized_lps.append([float(x) for x in lps[:n]]) - return normalized_ids, normalized_lps - - def assert_distill_arrays_aligned( - self, - token_sequences: List[List[int]], - distill_token_ids: Optional[List[List[List[int]]]], - distill_logprobs: Optional[List[List[List[float]]]], - ) -> None: - """ - Strict OPD invariant checks: - - both arrays exist - - one sequence row per token sequence - - one position row per token position - - ids/logprobs 
top-k row lengths match at each position - """ - if distill_token_ids is None or distill_logprobs is None: - raise AssertionError( - "[DISTILL] distill_token_ids/distill_logprobs must both be present." - ) - - if len(distill_token_ids) != len(token_sequences) or len(distill_logprobs) != len( - token_sequences - ): - raise AssertionError( - "[DISTILL] sequence count mismatch: " - f"tokens={len(token_sequences)} ids={len(distill_token_ids)} " - f"lps={len(distill_logprobs)}" - ) - - for seq_idx, tokens in enumerate(token_sequences): - expected_positions = len(tokens) - seq_ids = distill_token_ids[seq_idx] - seq_lps = distill_logprobs[seq_idx] - - if len(seq_ids) != expected_positions or len(seq_lps) != expected_positions: - raise AssertionError( - "[DISTILL] position count mismatch at seq " - f"{seq_idx}: tokens={expected_positions} ids={len(seq_ids)} " - f"lps={len(seq_lps)}" - ) - - for pos_idx in range(expected_positions): - if len(seq_ids[pos_idx]) != len(seq_lps[pos_idx]): - raise AssertionError( - "[DISTILL] top-k row mismatch at " - f"seq={seq_idx}, pos={pos_idx}: " - f"ids={len(seq_ids[pos_idx])}, lps={len(seq_lps[pos_idx])}" - ) - - def _normalize_messages( - self, raw_messages: Any, fallback_text: str - ) -> List[Dict[str, str]]: - """ - Normalize environment message payloads for chat/completions teacher fallback. - - Accepts already-structured message lists, plain strings, or unknown structures. 
- """ - if isinstance(raw_messages, str): - return [{"role": "user", "content": raw_messages}] - - if isinstance(raw_messages, list): - normalized: List[Dict[str, str]] = [] - for msg in raw_messages: - if ( - isinstance(msg, dict) - and "role" in msg - and "content" in msg - and isinstance(msg["content"], str) - ): - normalized.append( - {"role": str(msg["role"]), "content": msg["content"]} - ) - elif isinstance(msg, str): - normalized.append({"role": "user", "content": msg}) - if normalized: - return normalized - - return [{"role": "user", "content": fallback_text}] - - def _align_teacher_topk_to_tokens( - self, - seq_token_ids: List[List[int]], - seq_logprobs: List[List[float]], - target_token_len: int, - prefix_token_len: int = 0, - ) -> Tuple[List[List[int]], List[List[float]]]: - n = min(len(seq_token_ids), len(seq_logprobs)) - aligned_ids = list(seq_token_ids[:n]) - aligned_lps = list(seq_logprobs[:n]) - - if prefix_token_len > 0: - aligned_ids = aligned_ids[prefix_token_len:] - aligned_lps = aligned_lps[prefix_token_len:] - - aligned_ids = aligned_ids[:target_token_len] - aligned_lps = aligned_lps[:target_token_len] - - if len(aligned_ids) < target_token_len: - pad_count = target_token_len - len(aligned_ids) - aligned_ids.extend([[] for _ in range(pad_count)]) - aligned_lps.extend([[] for _ in range(pad_count)]) - - return aligned_ids, aligned_lps - - def _parse_completion_logprobs( - self, data: Dict, top_k: int - ) -> Tuple[List[List[int]], List[List[float]]]: - try: - choice = data.get("choices", [{}])[0] - logprobs_data = choice.get("logprobs", {}) - top_logprobs = logprobs_data.get("top_logprobs", []) - if not top_logprobs: - return [], [] - - seq_token_ids: List[List[int]] = [] - seq_logprobs: List[List[float]] = [] - for pos_logprobs in top_logprobs: - if pos_logprobs is None: - seq_token_ids.append([]) - seq_logprobs.append([]) - elif isinstance(pos_logprobs, dict): - sorted_items = sorted( - pos_logprobs.items(), key=lambda x: x[1], reverse=True - 
)[:top_k] - pos_ids: List[int] = [] - pos_lps: List[float] = [] - for token_str, logprob in sorted_items: - token_ids = self.tokenizer.encode( - token_str, add_special_tokens=False - ) - if token_ids: - pos_ids.append(int(token_ids[0])) - pos_lps.append(float(logprob)) - seq_token_ids.append(pos_ids) - seq_logprobs.append(pos_lps) - else: - seq_token_ids.append([]) - seq_logprobs.append([]) - return seq_token_ids, seq_logprobs - except Exception as e: - self.logger.warning("Error parsing completion logprobs: %s", e) - return [], [] - - def _parse_chat_logprobs( - self, data: Dict, top_k: int - ) -> Tuple[List[List[int]], List[List[float]]]: - try: - choice = data.get("choices", [{}])[0] - logprobs_data = choice.get("logprobs", {}) - if not logprobs_data: - return [], [] - - content = logprobs_data.get("content", []) - seq_token_ids: List[List[int]] = [] - seq_logprobs: List[List[float]] = [] - for token_data in content: - top_logprobs = token_data.get("top_logprobs", []) - pos_ids: List[int] = [] - pos_lps: List[float] = [] - for item in top_logprobs[:top_k]: - token_str = item.get("token", "") - logprob = item.get("logprob", 0.0) - token_ids = self.tokenizer.encode( - token_str, add_special_tokens=False - ) - if token_ids: - pos_ids.append(int(token_ids[0])) - pos_lps.append(float(logprob)) - seq_token_ids.append(pos_ids) - seq_logprobs.append(pos_lps) - return seq_token_ids, seq_logprobs - except Exception as e: - self.logger.warning("Error parsing chat logprobs: %s", e) - return [], [] - - -class _SafeFormatDict(dict): - def __missing__(self, key): - return "{" + key + "}" From 836c346406caad6ba7e1db540382985fb1d6e97b Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Fri, 27 Feb 2026 13:15:23 -0500 Subject: [PATCH 20/23] narrow down scope further --- environments/gsm8k_server.py | 22 +++++----------------- environments/math_server_zero.py | 10 +++++++--- example_trainer/README.md | 10 ++++++++++ 3 files changed, 22 insertions(+), 20 deletions(-) diff 
--git a/environments/gsm8k_server.py b/environments/gsm8k_server.py index 59db58f6..6ae5285b 100644 --- a/environments/gsm8k_server.py +++ b/environments/gsm8k_server.py @@ -123,11 +123,7 @@ class GSM8kEnv(BaseEnv): async def rollout_and_score_eval(self, question: str, answer: str) -> dict: """Rollout and score evaluation with detailed sample data collection.""" - # Important: use ManagedServer's default tokenizer resolution so it uses - # the underlying inference server tokenizer (e.g., Qwen) instead of the - # environment tokenizer. Passing self.tokenizer here can cause token-ID - # mismatch and gibberish generations when model/tokenizer families differ. - async with self.server.managed_server() as managed: + async with self.server.managed_server(tokenizer=self.tokenizer) as managed: completion = await managed.chat_completion( messages=[ {"role": "system", "content": system_prompt}, @@ -236,9 +232,7 @@ class GSM8kEnv(BaseEnv): "\\boxed{" + item["answer"].split("#")[-1].strip().replace(",", "") + "}" ) - # Important: do not force env tokenizer into ManagedServer for rollout. - # Let ManagedServer use the server's tokenizer to keep prompt token IDs aligned. 
- async with self.server.managed_server() as managed: + async with self.server.managed_server(tokenizer=self.tokenizer) as managed: chat_completions = await managed.chat_completion( messages=[{"role": "system", "content": system_prompt}, user_message], @@ -287,7 +281,7 @@ class GSM8kEnv(BaseEnv): if len(gold_parsed) != 0: # We require the answer to be provided in correct latex (no malformed operators) random.shuffle(rollout_group_data) - for idx, item in enumerate(rollout_group_data): + for item in rollout_group_data: # print(item[0][-1]["content"]) answer_parsed = parse( item["messages"][-1]["content"].split("")[-1], @@ -316,8 +310,7 @@ class GSM8kEnv(BaseEnv): logprobs = item["logprobs"] # remove obviously bad examples - valid_mask_count = len([1 for i in masks if i != -100]) - if valid_mask_count < 10: + if len([1 for i in masks if i != -100]) < 10: continue scores["tokens"].append(tokens) scores["masks"].append(masks) @@ -330,9 +323,6 @@ class GSM8kEnv(BaseEnv): for score in scores["scores"]: self.percent_correct_buffer.append(max(score, 0)) - if len(scores["scores"]) == 0: - return None - # check if all the same # print(scores['scores']) if all([score == 1 for score in scores["scores"]]): @@ -362,9 +352,7 @@ class GSM8kEnv(BaseEnv): percentage_of_range = min(percentage_of_range, 1.0) # Apply linear penalty scaling from 1.0 down to 0.0 scores["scores"].append(1.0 - percentage_of_range) - if self.config.ensure_scores_are_not_same and all( - [scores["scores"][0] == score for score in scores["scores"]] - ): + if all([scores["scores"][0] == score for score in scores["scores"]]): return None # If all the same, we return None return scores else: diff --git a/environments/math_server_zero.py b/environments/math_server_zero.py index 12076dab..1432ab4d 100644 --- a/environments/math_server_zero.py +++ b/environments/math_server_zero.py @@ -11,7 +11,6 @@ from typing import Dict, List, Optional, Tuple import wandb from datasets import load_dataset - from 
latex2sympy2_extended import NormalizationConfig from math_verify import LatexExtractionConfig, parse, verify from math_verify.errors import TimeoutException @@ -124,6 +123,8 @@ class MathEnv(BaseEnv): slurm=True, testing=False, ): + print("Initializing MathEnv") + print(f"Slurm: {slurm}, Testing: {testing}") super().__init__(config, server_configs, slurm, testing) self.percent_correct_buffer = list() self.eval_metrics = list() @@ -396,6 +397,7 @@ class MathEnv(BaseEnv): ) if len(self.normal_rollouts) > self.config.num_rollouts_to_keep: self.normal_rollouts.pop(0) + print(f"Collected {len(to_postprocess['scores'])} trajectories") return to_postprocess, to_backlog async def score(self, rollout_group_data: List) -> Optional[ScoredDataGroup]: @@ -480,7 +482,6 @@ class MathEnv(BaseEnv): and (not scores["overrides"][i].get("set_advantage_to_zero", False)) ] ) - return scores async def get_next_item(self): @@ -496,7 +497,10 @@ class MathEnv(BaseEnv): ) break except TypeError: - continue + print( + f"Error in getting next item, trying again, " + f"data: {next_item['question']} -> {next_item['final_answer']}" + ) return (prompt, answer, "normal") diff --git a/example_trainer/README.md b/example_trainer/README.md index aee1831d..34c7b7a7 100644 --- a/example_trainer/README.md +++ b/example_trainer/README.md @@ -8,6 +8,16 @@ This example uses `vLLM` for efficient inference during the (simulated) data gen **Note:** This script is intended as a *reference example* for API integration and basic training setup. It is not optimized for large-scale, efficient training. +## On-Policy Distillation Scope + +The current OPD integration in Atropos is transport-only: + +- `ScoredDataGroup` / API payloads support `distill_token_ids` and `distill_logprobs`. +- Atropos API stores and returns those fields through `/scored_data` and `/batch`. +- Teacher orchestration (teacher endpoint calls, prompt rendering, top-k fetching) is intentionally out of scope in this PR. 
+ +If you train with distillation, provide the two distill arrays from your environment or external data pipeline before posting to the API. + ### Custom vLLM Server The `vllm_api_server.py` file in this directory provides a customized vLLM API server implementation based on vLLM's native API. This server exposes enhanced endpoints for token and logprob tracking. The `VLLMServer` class in `atroposlib/envs/server_handling/vllm_server.py` can connect to this server for direct access to vLLM's `/generate` endpoint with full token-level logprobs. From 64d3ee1bd648a364aadfce4b25e1e449dbbf2904 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 27 Feb 2026 18:16:02 +0000 Subject: [PATCH 21/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- atroposlib/envs/base.py | 16 +++++++++++----- atroposlib/envs/server_handling/openai_server.py | 4 +++- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index 38b5965f..99a8c4e9 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -211,6 +211,8 @@ class BaseEnvConfig(BaseModel): "no thinking prompt is injected. 
Use HERMES_REASONING_PROMPT from " "eval_helpers for the standard Hermes reasoning prompt.", ) + + class BaseEnv(ABC): name: Optional[str] = None env_config_cls: BaseEnvConfig = BaseEnvConfig @@ -1436,13 +1438,13 @@ class BaseEnv(ABC): cli_passed_flags, openai_full_prefix ) # CLI args yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {}) - + # Debug logging for CLI args print(f"[CLI DEBUG] cli_passed_flags = {cli_passed_flags}") print(f"[CLI DEBUG] openai_full_prefix = {openai_full_prefix}") print(f"[CLI DEBUG] oai_cli_passed_args = {oai_cli_passed_args}") print(f"[CLI DEBUG] yaml_oai_config = {yaml_oai_config}") - + # Auto-convert ServerBaseline to APIServerConfig when CLI/YAML overrides are provided # This allows any environment to use --openai.* CLI args without modifying config_init # Use a new variable to avoid UnboundLocalError from closure scoping @@ -1456,7 +1458,7 @@ class BaseEnv(ABC): logger.info( "Auto-converted ServerBaseline to APIServerConfig for CLI/YAML overrides" ) - + if ( isinstance(effective_server_configs, list) and len(effective_server_configs) == 1 @@ -1470,13 +1472,17 @@ class BaseEnv(ABC): if isinstance(default_openai_config_, APIServerConfig) and isinstance( yaml_oai_config, dict ): - print(f"[CLI DEBUG] default_openai_config_.model_dump() = {default_openai_config_.model_dump()}") + print( + f"[CLI DEBUG] default_openai_config_.model_dump() = {default_openai_config_.model_dump()}" + ) openai_config_dict = merge_dicts( default_openai_config_.model_dump(), # Default APIServerConfig (or from class init) yaml_oai_config, oai_cli_passed_args, ) - print(f"[CLI DEBUG] openai_config_dict after merge = {openai_config_dict}") + print( + f"[CLI DEBUG] openai_config_dict after merge = {openai_config_dict}" + ) else: print( "[CLI DEBUG] Not merging: default_openai_config_ " diff --git a/atroposlib/envs/server_handling/openai_server.py b/atroposlib/envs/server_handling/openai_server.py index 24582273..fecc5828 100644 --- 
a/atroposlib/envs/server_handling/openai_server.py +++ b/atroposlib/envs/server_handling/openai_server.py @@ -200,7 +200,9 @@ def resolve_openai_configs( ) from e elif isinstance(default_server_configs, APIServerConfig): # Check APIServerConfig BEFORE ServerBaseline since APIServerConfig inherits from ServerBaseline - logger.info("Using single OpenAI server configuration based on merged settings (default/YAML/CLI).") + logger.info( + "Using single OpenAI server configuration based on merged settings (default/YAML/CLI)." + ) try: final_openai_config = APIServerConfig(**openai_config_dict) except Exception as e: From 35587cbdc05e19d74ab4f12f641e9cfd732c7233 Mon Sep 17 00:00:00 2001 From: Jai Suphavadeeprasit Date: Fri, 27 Feb 2026 16:13:51 -0500 Subject: [PATCH 22/23] logger changes --- atroposlib/api/server.py | 7 +- atroposlib/envs/base.py | 89 ++++++++++--------- .../envs/server_handling/managed_server.py | 19 ++-- 3 files changed, 63 insertions(+), 52 deletions(-) diff --git a/atroposlib/api/server.py b/atroposlib/api/server.py index 978a6f25..b9327b75 100644 --- a/atroposlib/api/server.py +++ b/atroposlib/api/server.py @@ -1,4 +1,5 @@ import gzip +import logging import time import uuid from typing import Any, Dict, List, Optional @@ -25,6 +26,7 @@ MIN_ENV_WEIGHT = ( # Message import removed - using Dict[str, Any] for more flexible validation app = FastAPI(title="AtroposLib API") +logger = logging.getLogger(__name__) app.add_middleware( CORSMiddleware, @@ -391,7 +393,10 @@ async def get_batch(): app.state.curr_batch.append(batch) curr_batch = app.state.curr_batch.pop() # check length before sending - print(f"Sending batch of {sum(len(x['tokens']) for x in curr_batch)} sequences") + logger.info( + "Sending batch of %s sequences", + sum(len(x["tokens"]) for x in curr_batch), + ) return {"batch": curr_batch} diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index 99a8c4e9..a466249c 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ 
-286,7 +286,7 @@ class BaseEnv(ABC): counter += 1 path_changed = True if path_changed: - print( + logger.info( f"Changed data path to {self.config.data_path_to_save_groups} because {original_path} already exists." # noqa: E501 ) @@ -364,7 +364,7 @@ class BaseEnv(ABC): to_postprocess["group_overrides"] = {} to_postprocess["overrides"] = [] to_postprocess["images"] = [] - print("Processing results") + logger.debug("Processing results") for result in results: to_postprocess["tokens"].append(result[0]["tokens"]) to_postprocess["masks"].append(result[0]["masks"]) @@ -444,7 +444,7 @@ class BaseEnv(ABC): setattr(self, key, data[key]) def save_checkpoint(self, step, data=None): - print(f"Saving checkpoint at step {step} with data {data}") + logger.info("Saving checkpoint at step %s with data %s", step, data) if data is None: # Don't have anything to save, abort return @@ -542,7 +542,7 @@ class BaseEnv(ABC): self.config.total_steps = data["num_steps"] if self.config.total_steps == -1: raise ValueError("Total steps not set in config or server!") - print( + logger.info( f"Initialized env with id {self.env_id}: " f"curr_step: {self.curr_step}, " f"checkpoint_dir: {self.checkpoint_dir}, " @@ -779,7 +779,7 @@ class BaseEnv(ABC): with open(filepath, "w") as f: json.dump(eval_result, f, indent=2) - print(f"Evaluation results saved to {filepath}") + logger.info("Evaluation results saved to %s", filepath) # Write samples to JSONL file if provided if samples: @@ -789,7 +789,7 @@ class BaseEnv(ABC): with jsonlines.open(samples_filepath, "w") as writer: for sample in samples: writer.write(sample) - print(f"Evaluation samples saved to {samples_filepath}") + logger.info("Evaluation samples saved to %s", samples_filepath) @retry( stop=stop_after_attempt(3), @@ -823,7 +823,7 @@ class BaseEnv(ABC): elif resp.status >= 400: logging.error(f"Client error: {resp.status}, not retrying") return - print(await resp.text()) + logger.debug(await resp.text()) def _post_json_with_compression( self, @@ 
-927,7 +927,9 @@ class BaseEnv(ABC): if self.jsonl_writer is not None: self.jsonl_writer.write(group) - print(f"Wrote scored group to {self.config.data_path_to_save_groups}") + logger.info( + "Wrote scored group to %s", self.config.data_path_to_save_groups + ) valid_groups.append(group) @@ -948,7 +950,7 @@ class BaseEnv(ABC): if isinstance(data_to_send_to_api, dict) else f"{len(data_to_send_to_api)} ScoredDataGroups" ) - print(f"Failed to send {data_type_str} after retries: {e}") + logger.error("Failed to send %s after retries: %s", data_type_str, e) async def handle_env( self, item_uuid: str @@ -958,7 +960,7 @@ class BaseEnv(ABC): """ item = self.running_items.get(item_uuid)["item"] if item is None: - print(f"item {item_uuid} not found... returning") + logger.warning("item %s not found... returning", item_uuid) return None start_time = time.time() logger.debug(f"handle_env: Starting with item: {item}") @@ -979,7 +981,7 @@ class BaseEnv(ABC): to_postprocess = await self.postprocess_histories(to_postprocess) except Exception as e: logger.error(f"Error in scoring: {item}") - print(e) + logger.error("Scoring exception: %s", e) to_postprocess = None self.running_items.pop(item_uuid, None) duration = max(0.0, time.time() - start_time) @@ -1120,10 +1122,9 @@ class BaseEnv(ABC): ), ) max_num_workers = max(max_num_workers, min_workers_to_fill_self_queue) - print( + logger.info( f"max_num_workers: {max_num_workers}, queue size: {self.status_dict['queue_size']}, " - f"workers: {len(self.workers)}, self_queue_size: {self.status_dict['self_queue_size']}", - flush=True, + f"workers: {len(self.workers)}, self_queue_size: {self.status_dict['self_queue_size']}" ) if (self.curr_step == 0) and (len(self.workers) == 0): # We are starting up, so we should just skip the append to the list @@ -1131,10 +1132,9 @@ class BaseEnv(ABC): else: self.workers_added_list.append(max_num_workers - len(self.workers)) if len(self.workers) > max_num_workers: - print( + logger.info( f"len(self.workers) 
> max_num_workers: {len(self.workers)} > {max_num_workers}, " - "sending workers to backlog", - flush=True, + "sending workers to backlog" ) num_to_reduce = len(self.workers) - max_num_workers running_items_to_remove = list(self.running_items.keys())[:num_to_reduce] @@ -1280,18 +1280,22 @@ class BaseEnv(ABC): # Initialize the processing self.curr_step = 0 - print(f"Starting to process {self.n_groups_to_process} groups...") + logger.info("Starting to process %s groups...", self.n_groups_to_process) # Process the required number of groups while self.curr_step < self.n_groups_to_process: # Get an item to process item = await self.get_next_item() if item is None: - print("No more items to process") + logger.info("No more items to process") break # Process the group - print(f"Processing group {self.curr_step + 1}/{self.n_groups_to_process}") + logger.info( + "Processing group %s/%s", + self.curr_step + 1, + self.n_groups_to_process, + ) # Collect trajectories with the specified group size # Override the group_size temporarily @@ -1314,13 +1318,13 @@ class BaseEnv(ABC): await self.wandb_log() self.curr_step += 1 - print( + logger.info( f"Successfully processed group {self.curr_step}/{self.n_groups_to_process}" ) else: - print("Failed to process group, retrying...") + logger.warning("Failed to process group, retrying...") - print(f"Completed processing {self.curr_step} groups") + logger.info("Completed processing %s groups", self.curr_step) # Close the output file if it's open if self.jsonl_writer is not None: @@ -1354,8 +1358,7 @@ class BaseEnv(ABC): """Handles exceptions with clean output for known error types.""" if isinstance(ex, FailedExecutionException): # Handle argparse errors (already printed by argparse) - print() - print(ex.message.split("error: ")[-1]) + logger.error(ex.message.split("error: ")[-1]) return 2 raise ex @@ -1416,7 +1419,7 @@ class BaseEnv(ABC): if self.config is not None: with open(self.config, "r") as f: yaml_config = yaml.safe_load(f) - 
print(f"Loaded config from {self.config}") + logger.info("Loaded config from %s", self.config) else: yaml_config = {} @@ -1440,11 +1443,11 @@ class BaseEnv(ABC): yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {}) # Debug logging for CLI args - print(f"[CLI DEBUG] cli_passed_flags = {cli_passed_flags}") - print(f"[CLI DEBUG] openai_full_prefix = {openai_full_prefix}") - print(f"[CLI DEBUG] oai_cli_passed_args = {oai_cli_passed_args}") - print(f"[CLI DEBUG] yaml_oai_config = {yaml_oai_config}") - + logger.debug("[CLI DEBUG] cli_passed_flags = %s", cli_passed_flags) + logger.debug("[CLI DEBUG] openai_full_prefix = %s", openai_full_prefix) + logger.debug("[CLI DEBUG] oai_cli_passed_args = %s", oai_cli_passed_args) + logger.debug("[CLI DEBUG] yaml_oai_config = %s", yaml_oai_config) + # Auto-convert ServerBaseline to APIServerConfig when CLI/YAML overrides are provided # This allows any environment to use --openai.* CLI args without modifying config_init # Use a new variable to avoid UnboundLocalError from closure scoping @@ -1472,19 +1475,21 @@ class BaseEnv(ABC): if isinstance(default_openai_config_, APIServerConfig) and isinstance( yaml_oai_config, dict ): - print( - f"[CLI DEBUG] default_openai_config_.model_dump() = {default_openai_config_.model_dump()}" + logger.debug( + "[CLI DEBUG] default_openai_config_.model_dump() = %s", + default_openai_config_.model_dump(), ) openai_config_dict = merge_dicts( default_openai_config_.model_dump(), # Default APIServerConfig (or from class init) yaml_oai_config, oai_cli_passed_args, ) - print( - f"[CLI DEBUG] openai_config_dict after merge = {openai_config_dict}" + logger.debug( + "[CLI DEBUG] openai_config_dict after merge = %s", + openai_config_dict, ) else: - print( + logger.debug( "[CLI DEBUG] Not merging: default_openai_config_ " f"type={type(default_openai_config_)}, " f"yaml_oai_config type={type(yaml_oai_config)}" @@ -1637,7 +1642,7 @@ class BaseEnv(ABC): if self.config is not None: with open(self.config, "r") as f: 
yaml_config = yaml.safe_load(f) - print(f"Loaded config from {self.config}") + logger.info("Loaded config from %s", self.config) else: yaml_config = {} @@ -1810,7 +1815,7 @@ class BaseEnv(ABC): "data_path_to_save_groups must be set for process mode" ) - print( + logger.info( f"Processing {env_config.total_steps} groups of " f"{env_config.group_size} responses and " f"writing to {env_config.data_path_to_save_groups}" @@ -1906,7 +1911,7 @@ class BaseEnv(ABC): if self.config is not None: with open(self.config, "r") as f: yaml_config = yaml.safe_load(f) - print(f"Loaded config from {self.config}") + logger.info("Loaded config from %s", self.config) else: yaml_config = {} @@ -2092,7 +2097,7 @@ class BaseEnv(ABC): yaml.dump( config_dict, f, default_flow_style=False, sort_keys=False ) - print(f"Dumped evaluate config to {config_filepath}") + logger.info("Dumped evaluate config to %s", config_filepath) # --- Create and Run Environment --- # Create the environment instance @@ -2103,7 +2108,7 @@ class BaseEnv(ABC): testing=server_manager_config.testing, ) - print("Running evaluation...") + logger.info("Running evaluation...") # Handle the case where we might already be in an event loop try: loop = asyncio.get_running_loop() @@ -2112,6 +2117,6 @@ class BaseEnv(ABC): except RuntimeError: asyncio.run(env._run_evaluate()) - print("Evaluation completed.") + logger.info("Evaluation completed.") return CliEvaluateConfig diff --git a/atroposlib/envs/server_handling/managed_server.py b/atroposlib/envs/server_handling/managed_server.py index cb14d210..ff472561 100644 --- a/atroposlib/envs/server_handling/managed_server.py +++ b/atroposlib/envs/server_handling/managed_server.py @@ -8,6 +8,7 @@ This wrapper maintains a tree structure of sequences, where: """ import os +import logging import time import uuid import warnings @@ -23,6 +24,8 @@ from pydantic import BaseModel from atroposlib.envs.server_handling.server_baseline import APIServer +logger = logging.getLogger(__name__) + class 
SequenceNode(BaseModel): """ @@ -292,16 +295,14 @@ class ManagedServer: if self._debug_requests_enabled(): msg_count = len(messages) prompt_preview = prompt.replace("\n", "\\n")[:600] - print( - f"[ATROPOS_REQ_DEBUG] chat_completion messages={msg_count} " - f"n={completion_kwargs.get('n')} max_tokens={completion_kwargs.get('max_tokens')} " - f"temperature={completion_kwargs.get('temperature')}", - flush=True, - ) - print( - f"[ATROPOS_REQ_DEBUG] prompt_preview={prompt_preview!r}", - flush=True, + logger.debug( + "[ATROPOS_REQ_DEBUG] chat_completion messages=%s n=%s max_tokens=%s temperature=%s", + msg_count, + completion_kwargs.get("n"), + completion_kwargs.get("max_tokens"), + completion_kwargs.get("temperature"), ) + logger.debug("[ATROPOS_REQ_DEBUG] prompt_preview=%r", prompt_preview) # Set model name if not provided if "model" not in completion_kwargs: From 216c1f5899a6e493fe69b18458d73dc187fa284d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 27 Feb 2026 21:17:54 +0000 Subject: [PATCH 23/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- atroposlib/envs/base.py | 6 ++++-- atroposlib/envs/server_handling/managed_server.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index a466249c..3d3b6c20 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -1445,9 +1445,11 @@ class BaseEnv(ABC): # Debug logging for CLI args logger.debug("[CLI DEBUG] cli_passed_flags = %s", cli_passed_flags) logger.debug("[CLI DEBUG] openai_full_prefix = %s", openai_full_prefix) - logger.debug("[CLI DEBUG] oai_cli_passed_args = %s", oai_cli_passed_args) + logger.debug( + "[CLI DEBUG] oai_cli_passed_args = %s", oai_cli_passed_args + ) logger.debug("[CLI DEBUG] yaml_oai_config = %s", yaml_oai_config) - + # Auto-convert ServerBaseline to APIServerConfig when CLI/YAML overrides 
are provided # This allows any environment to use --openai.* CLI args without modifying config_init # Use a new variable to avoid UnboundLocalError from closure scoping diff --git a/atroposlib/envs/server_handling/managed_server.py b/atroposlib/envs/server_handling/managed_server.py index ff472561..c1358dc6 100644 --- a/atroposlib/envs/server_handling/managed_server.py +++ b/atroposlib/envs/server_handling/managed_server.py @@ -7,8 +7,8 @@ This wrapper maintains a tree structure of sequences, where: - Branching occurs organically from different contexts and n > 1 completions """ -import os import logging +import os import time import uuid import warnings