diff --git a/README.md b/README.md index b7ed312e..3b533a9b 100644 --- a/README.md +++ b/README.md @@ -256,6 +256,48 @@ Atropos repo contains an example trainer that should primarily be used as a refe To use the example trainer, see this page: [training example guide](example_trainer/README.md) +## On-Policy Distillation (API + ScoredDataGroup Contract) + +Atropos now supports OPD at the transport layer by carrying distillation arrays +through `ScoredDataGroup` and the API queue/batch endpoints. + +### Scope of this change + +- No teacher fetching/orchestration in `BaseEnv`. +- Environments or external pipelines are responsible for populating distillation arrays. +- API stores and returns those arrays unchanged. + +### Distillation payload fields + +Each scored group may include: + +- `distill_token_ids`: shape `[sequence][position][top_k]` +- `distill_logprobs`: shape `[sequence][position][top_k]` + +These fields are optional, and when present are forwarded from: + +- environment -> `/scored_data` or `/scored_data_list` +- API queue -> `/batch` -> trainer + +### Minimal producer example (environment side) + +```python +scores["distill_token_ids"] = distill_token_ids +scores["distill_logprobs"] = distill_logprobs +``` + +### Minimal consumer check (trainer/debug side) + +```bash +curl -s http://localhost:8002/latest_example | jq '{has_ids:(.distill_token_ids!=null), has_lps:(.distill_logprobs!=null)}' +``` + +### Notes + +- The API does not validate cross-field semantics beyond schema typing. +- Trainers should validate alignment assumptions they require (sequence length, per-position top-k, etc.). +- Teacher-side architecture and prompt/rendering strategy are intentionally out of scope for this PR. + --- ## Testing and Debugging Tools diff --git a/atroposlib/api/server.py b/atroposlib/api/server.py index 4a94c6d8..b9327b75 100644 --- a/atroposlib/api/server.py +++ b/atroposlib/api/server.py @@ -1,4 +1,5 @@ import gzip +import logging import time import uuid from typing import Any, Dict, List, Optional @@ -25,6 +26,7 @@ MIN_ENV_WEIGHT = ( # Message import removed - using Dict[str, Any] for more flexible validation app = FastAPI(title="AtroposLib API") +logger = logging.getLogger(__name__) app.add_middleware( CORSMiddleware, @@ -145,6 +147,10 @@ class ScoredData(BaseModel): group_overrides: Optional[dict] = None images: Optional[Any] = None env_id: Optional[int] = None # ID of the environment that generated this data + # On-policy distillation (new format): parallel token ids + logprobs. + # Shape for both: [sequence][position][top_k] + distill_token_ids: Optional[List[List[List[int]]]] = None + distill_logprobs: Optional[List[List[List[float]]]] = None @field_validator("messages", mode="before") @classmethod @@ -182,6 +188,8 @@ def _scored_data_to_dict(scored_data: ScoredData) -> Dict[str, Any]: "group_overrides": scored_data.group_overrides, "images": scored_data.images, "env_id": scored_data.env_id, + "distill_token_ids": scored_data.distill_token_ids, + "distill_logprobs": scored_data.distill_logprobs, } @@ -385,7 +393,10 @@ async def get_batch(): app.state.curr_batch.append(batch) curr_batch = app.state.curr_batch.pop() # check length before sending - print(f"Sending batch of {sum(len(x['tokens']) for x in curr_batch)} sequences") + logger.info( + "Sending batch of %s sequences", + sum(len(x["tokens"]) for x in curr_batch), + ) return {"batch": curr_batch} diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py index e9b672a6..3d3b6c20 100644 --- a/atroposlib/envs/base.py +++ b/atroposlib/envs/base.py @@ -66,6 +66,10 @@ class ScoredDataGroup(TypedDict): group_overrides: Optional[Dict] overrides: Optional[List[Dict]] images: Optional[Any] + # On-policy distillation (new format): parallel token ids + logprobs. + # distill_token_ids/distill_logprobs are [sequence][position][top_k] + distill_token_ids: Optional[List[List[List[int]]]] + distill_logprobs: Optional[List[List[List[float]]]] class ScoredDataItem(TypedDict): @@ -78,6 +82,9 @@ class ScoredDataItem(TypedDict): group_overrides: Optional[Dict] overrides: Optional[Dict] images: Optional[Any] + # On-policy distillation (new format): parallel token ids + logprobs per position. + distill_token_ids: Optional[List[List[int]]] + distill_logprobs: Optional[List[List[float]]] class EvalHandlingEnum(Enum): @@ -279,7 +286,7 @@ class BaseEnv(ABC): counter += 1 path_changed = True if path_changed: - print( + logger.info( f"Changed data path to {self.config.data_path_to_save_groups} because {original_path} already exists." # noqa: E501 ) @@ -357,7 +364,7 @@ class BaseEnv(ABC): to_postprocess["group_overrides"] = {} to_postprocess["overrides"] = [] to_postprocess["images"] = [] - print("Processing results") + logger.debug("Processing results") for result in results: to_postprocess["tokens"].append(result[0]["tokens"]) to_postprocess["masks"].append(result[0]["masks"]) @@ -437,7 +444,7 @@ class BaseEnv(ABC): setattr(self, key, data[key]) def save_checkpoint(self, step, data=None): - print(f"Saving checkpoint at step {step} with data {data}") + logger.info("Saving checkpoint at step %s with data %s", step, data) if data is None: # Don't have anything to save, abort return @@ -535,7 +542,7 @@ class BaseEnv(ABC): self.config.total_steps = data["num_steps"] if self.config.total_steps == -1: raise ValueError("Total steps not set in config or server!") - print( + logger.info( f"Initialized env with id {self.env_id}: " f"curr_step: {self.curr_step}, " f"checkpoint_dir: {self.checkpoint_dir}, " @@ -772,7 +779,7 @@ class BaseEnv(ABC): with open(filepath, "w") as f: json.dump(eval_result, f, indent=2) - print(f"Evaluation results saved to {filepath}") + logger.info("Evaluation results saved to %s", filepath) # Write samples to JSONL file if provided if samples: @@ -782,7 +789,7 @@ class BaseEnv(ABC): with jsonlines.open(samples_filepath, "w") as writer: for sample in samples: writer.write(sample) - print(f"Evaluation samples saved to {samples_filepath}") + logger.info("Evaluation samples saved to %s", samples_filepath) @retry( stop=stop_after_attempt(3), @@ -816,7 +823,7 @@ class BaseEnv(ABC): elif resp.status >= 400: logging.error(f"Client error: {resp.status}, not retrying") return - print(await resp.text()) + logger.debug(await resp.text()) def _post_json_with_compression( self, @@ -888,6 +895,8 @@ class BaseEnv(ABC): group.setdefault("ref_logprobs", None) group.setdefault("overrides", None) group.setdefault("group_overrides", None) + group.setdefault("distill_token_ids", None) + group.setdefault("distill_logprobs", None) for mask in group["masks"]: self.completion_lengths.append(sum(m != -100 for m in mask)) @@ -905,7 +914,12 @@ class BaseEnv(ABC): if self.config.include_messages and group.get("messages") is None: group["messages"] = [ - self.tokenizer.decode(group["tokens"][i]) + [ + { + "role": "user", + "content": self.tokenizer.decode(group["tokens"][i]), + } + ] for i in range(len(group["tokens"])) ] @@ -913,7 +927,9 @@ class BaseEnv(ABC): if self.jsonl_writer is not None: self.jsonl_writer.write(group) - print(f"Wrote scored group to {self.config.data_path_to_save_groups}") + logger.info( + "Wrote scored group to %s", self.config.data_path_to_save_groups + ) valid_groups.append(group) @@ -934,7 +950,7 @@ class BaseEnv(ABC): if isinstance(data_to_send_to_api, dict) else f"{len(data_to_send_to_api)} ScoredDataGroups" ) - print(f"Failed to send {data_type_str} after retries: {e}") + logger.error("Failed to send %s after retries: %s", data_type_str, e) async def handle_env( self, item_uuid: str @@ -944,7 +960,7 @@ class BaseEnv(ABC): """ item = self.running_items.get(item_uuid)["item"] if item is None: - print(f"item {item_uuid} not found... returning") + logger.warning("item %s not found... returning", item_uuid) return None start_time = time.time() logger.debug(f"handle_env: Starting with item: {item}") @@ -965,7 +981,7 @@ class BaseEnv(ABC): to_postprocess = await self.postprocess_histories(to_postprocess) except Exception as e: logger.error(f"Error in scoring: {item}") - print(e) + logger.error("Scoring exception: %s", e) to_postprocess = None self.running_items.pop(item_uuid, None) duration = max(0.0, time.time() - start_time) @@ -1106,10 +1122,9 @@ class BaseEnv(ABC): ), ) max_num_workers = max(max_num_workers, min_workers_to_fill_self_queue) - print( + logger.info( f"max_num_workers: {max_num_workers}, queue size: {self.status_dict['queue_size']}, " - f"workers: {len(self.workers)}, self_queue_size: {self.status_dict['self_queue_size']}", - flush=True, + f"workers: {len(self.workers)}, self_queue_size: {self.status_dict['self_queue_size']}" ) if (self.curr_step == 0) and (len(self.workers) == 0): # We are starting up, so we should just skip the append to the list @@ -1117,10 +1132,9 @@ class BaseEnv(ABC): else: self.workers_added_list.append(max_num_workers - len(self.workers)) if len(self.workers) > max_num_workers: - print( + logger.info( f"len(self.workers) > max_num_workers: {len(self.workers)} > {max_num_workers}, " - "sending workers to backlog", - flush=True, + "sending workers to backlog" ) num_to_reduce = len(self.workers) - max_num_workers running_items_to_remove = list(self.running_items.keys())[:num_to_reduce] @@ -1266,18 +1280,22 @@ class BaseEnv(ABC): # Initialize the processing self.curr_step = 0 - print(f"Starting to process {self.n_groups_to_process} groups...") + logger.info("Starting to process %s groups...", self.n_groups_to_process) # Process the required number of groups while self.curr_step < self.n_groups_to_process: # Get an item to process item = await self.get_next_item() if item is None: - print("No more items to process") + logger.info("No more items to process") break # Process the group - print(f"Processing group {self.curr_step + 1}/{self.n_groups_to_process}") + logger.info( + "Processing group %s/%s", + self.curr_step + 1, + self.n_groups_to_process, + ) # Collect trajectories with the specified group size # Override the group_size temporarily @@ -1300,13 +1318,13 @@ class BaseEnv(ABC): await self.wandb_log() self.curr_step += 1 - print( + logger.info( f"Successfully processed group {self.curr_step}/{self.n_groups_to_process}" ) else: - print("Failed to process group, retrying...") + logger.warning("Failed to process group, retrying...") - print(f"Completed processing {self.curr_step} groups") + logger.info("Completed processing %s groups", self.curr_step) # Close the output file if it's open if self.jsonl_writer is not None: @@ -1340,8 +1358,7 @@ class BaseEnv(ABC): """Handles exceptions with clean output for known error types.""" if isinstance(ex, FailedExecutionException): # Handle argparse errors (already printed by argparse) - print() - print(ex.message.split("error: ")[-1]) + logger.error(ex.message.split("error: ")[-1]) return 2 raise ex @@ -1402,7 +1419,7 @@ class BaseEnv(ABC): if self.config is not None: with open(self.config, "r") as f: yaml_config = yaml.safe_load(f) - print(f"Loaded config from {self.config}") + logger.info("Loaded config from %s", self.config) else: yaml_config = {} @@ -1424,31 +1441,61 @@ class BaseEnv(ABC): cli_passed_flags, openai_full_prefix ) # CLI args yaml_oai_config = yaml_config.get(OPENAI_NAMESPACE, {}) - if isinstance(default_server_configs, ServerBaseline) and ( + + # Debug logging for CLI args + logger.debug("[CLI DEBUG] cli_passed_flags = %s", cli_passed_flags) + logger.debug("[CLI DEBUG] openai_full_prefix = %s", openai_full_prefix) + logger.debug( + "[CLI DEBUG] oai_cli_passed_args = %s", oai_cli_passed_args + ) + logger.debug("[CLI DEBUG] yaml_oai_config = %s", yaml_oai_config) + + # Auto-convert ServerBaseline to APIServerConfig when CLI/YAML overrides are provided + # This allows any environment to use --openai.* CLI args without modifying config_init + # Use a new variable to avoid UnboundLocalError from closure scoping + effective_server_configs = default_server_configs + if isinstance(effective_server_configs, ServerBaseline) and ( oai_cli_passed_args or yaml_oai_config ): - raise ValueError( - "ServerBaseline is not compatible with OpenAI-namespaced CLI arguments. Please edit `config_init` directly or use APIServerConfig." # noqa: E501 + # Convert ServerBaseline to APIServerConfig, preserving common fields + baseline_dict = effective_server_configs.model_dump() + effective_server_configs = APIServerConfig(**baseline_dict) + logger.info( + "Auto-converted ServerBaseline to APIServerConfig for CLI/YAML overrides" ) + if ( - isinstance(default_server_configs, list) - and len(default_server_configs) == 1 + isinstance(effective_server_configs, list) + and len(effective_server_configs) == 1 ): # can't use the same var name because it shadows the class variable and we get an error - default_openai_config_ = default_server_configs[0] + default_openai_config_ = effective_server_configs[0] else: - default_openai_config_ = default_server_configs + default_openai_config_ = effective_server_configs if isinstance(yaml_oai_config, list) and len(yaml_oai_config) == 1: yaml_oai_config = yaml_oai_config[0] if isinstance(default_openai_config_, APIServerConfig) and isinstance( yaml_oai_config, dict ): + logger.debug( + "[CLI DEBUG] default_openai_config_.model_dump() = %s", + default_openai_config_.model_dump(), + ) openai_config_dict = merge_dicts( default_openai_config_.model_dump(), # Default APIServerConfig (or from class init) yaml_oai_config, oai_cli_passed_args, ) + logger.debug( + "[CLI DEBUG] openai_config_dict after merge = %s", + openai_config_dict, + ) else: + logger.debug( + "[CLI DEBUG] Not merging: default_openai_config_ " + f"type={type(default_openai_config_)}, " + f"yaml_oai_config type={type(yaml_oai_config)}" + ) openai_config_dict = {} # 3. Server Manager Configuration (slurm, testing - not namespaced) @@ -1487,7 +1534,7 @@ class BaseEnv(ABC): # Determine the final server_configs, handling single, multiple servers, and overrides. openai_configs = resolve_openai_configs( - default_server_configs=default_server_configs, + default_server_configs=effective_server_configs, openai_config_dict=openai_config_dict, yaml_config=yaml_config, cli_passed_flags=cli_passed_flags, @@ -1597,7 +1644,7 @@ class BaseEnv(ABC): if self.config is not None: with open(self.config, "r") as f: yaml_config = yaml.safe_load(f) - print(f"Loaded config from {self.config}") + logger.info("Loaded config from %s", self.config) else: yaml_config = {} @@ -1770,7 +1817,7 @@ class BaseEnv(ABC): "data_path_to_save_groups must be set for process mode" ) - print( + logger.info( f"Processing {env_config.total_steps} groups of " f"{env_config.group_size} responses and " f"writing to {env_config.data_path_to_save_groups}" @@ -1866,7 +1913,7 @@ class BaseEnv(ABC): if self.config is not None: with open(self.config, "r") as f: yaml_config = yaml.safe_load(f) - print(f"Loaded config from {self.config}") + logger.info("Loaded config from %s", self.config) else: yaml_config = {} @@ -2052,7 +2099,7 @@ class BaseEnv(ABC): yaml.dump( config_dict, f, default_flow_style=False, sort_keys=False ) - print(f"Dumped evaluate config to {config_filepath}") + logger.info("Dumped evaluate config to %s", config_filepath) # --- Create and Run Environment --- # Create the environment instance @@ -2063,7 +2110,7 @@ class BaseEnv(ABC): testing=server_manager_config.testing, ) - print("Running evaluation...") + logger.info("Running evaluation...") # Handle the case where we might already be in an event loop try: loop = asyncio.get_running_loop() @@ -2072,6 +2119,6 @@ class BaseEnv(ABC): except RuntimeError: asyncio.run(env._run_evaluate()) - print("Evaluation completed.") + logger.info("Evaluation completed.") return CliEvaluateConfig diff --git a/atroposlib/envs/server_handling/managed_server.py b/atroposlib/envs/server_handling/managed_server.py index 0918c325..c1358dc6 100644 --- a/atroposlib/envs/server_handling/managed_server.py +++ b/atroposlib/envs/server_handling/managed_server.py @@ -7,6 +7,8 @@ This wrapper maintains a tree structure of sequences, where: - Branching occurs organically from different contexts and n > 1 completions """ +import logging +import os import time import uuid import warnings @@ -22,6 +24,8 @@ from pydantic import BaseModel from atroposlib.envs.server_handling.server_baseline import APIServer +logger = logging.getLogger(__name__) + class SequenceNode(BaseModel): """ @@ -131,6 +135,10 @@ class ManagedServer: # Fallback for tokenizers without chat template return "\n".join([f"{m['role']}: {m['content']}" for m in messages]) + def _debug_requests_enabled(self) -> bool: + """Enable verbose request construction logs with ATROPOS_DEBUG_REQUESTS=1.""" + return os.getenv("ATROPOS_DEBUG_REQUESTS", "0") == "1" + def _find_extending_node(self, input_text: str) -> Optional[SequenceNode]: """ Find a node that this input extends (default mode). @@ -284,6 +292,17 @@ class ManagedServer: completion_kwargs = kwargs.copy() completion_kwargs["prompt"] = prompt completion_kwargs.pop("messages", None) + if self._debug_requests_enabled(): + msg_count = len(messages) + prompt_preview = prompt.replace("\n", "\\n")[:600] + logger.debug( + "[ATROPOS_REQ_DEBUG] chat_completion messages=%s n=%s max_tokens=%s temperature=%s", + msg_count, + completion_kwargs.get("n"), + completion_kwargs.get("max_tokens"), + completion_kwargs.get("temperature"), + ) + logger.debug("[ATROPOS_REQ_DEBUG] prompt_preview=%r", prompt_preview) # Set model name if not provided if "model" not in completion_kwargs: diff --git a/atroposlib/envs/server_handling/openai_server.py b/atroposlib/envs/server_handling/openai_server.py index fce40f80..fecc5828 100644 --- a/atroposlib/envs/server_handling/openai_server.py +++ b/atroposlib/envs/server_handling/openai_server.py @@ -198,7 +198,21 @@ def resolve_openai_configs( raise FailedExecutionException( f"Error parsing multi-server OpenAI configuration from YAML under '{OPENAI_NAMESPACE}': {e}" ) from e + elif isinstance(default_server_configs, APIServerConfig): + # Check APIServerConfig BEFORE ServerBaseline since APIServerConfig inherits from ServerBaseline + logger.info( + "Using single OpenAI server configuration based on merged settings (default/YAML/CLI)." + ) + try: + final_openai_config = APIServerConfig(**openai_config_dict) + except Exception as e: + raise FailedExecutionException( + f"Error creating final OpenAI configuration from merged settings: {e}\n" + f"Merged Dict: {openai_config_dict}" + ) from e + server_configs = final_openai_config elif isinstance(default_server_configs, ServerBaseline): + # Pure ServerBaseline (not APIServerConfig) - no CLI overrides possible logger.info("Using ServerBaseline configuration.") server_configs = default_server_configs elif is_multi_server_default: diff --git a/example_trainer/README.md b/example_trainer/README.md index aee1831d..34c7b7a7 100644 --- a/example_trainer/README.md +++ b/example_trainer/README.md @@ -8,6 +8,16 @@ This example uses `vLLM` for efficient inference during the (simulated) data gen **Note:** This script is intended as a *reference example* for API integration and basic training setup. It is not optimized for large-scale, efficient training. +## On-Policy Distillation Scope + +The current OPD integration in Atropos is transport-only: + +- `ScoredDataGroup` / API payloads support `distill_token_ids` and `distill_logprobs`. +- Atropos API stores and returns those fields through `/scored_data` and `/batch`. +- Teacher orchestration (teacher endpoint calls, prompt rendering, top-k fetching) is intentionally out of scope in this PR. + +If you train with distillation, provide the two distill arrays from your environment or external data pipeline before posting to the API. + ### Custom vLLM Server The `vllm_api_server.py` file in this directory provides a customized vLLM API server implementation based on vLLM's native API. This server exposes enhanced endpoints for token and logprob tracking. The `VLLMServer` class in `atroposlib/envs/server_handling/vllm_server.py` can connect to this server for direct access to vLLM's `/generate` endpoint with full token-level logprobs.