diff --git a/atroposlib/envs/README_SKYRL.md b/atroposlib/envs/README_SKYRL.md
index d2787cc1..5eb71858 100644
--- a/atroposlib/envs/README_SKYRL.md
+++ b/atroposlib/envs/README_SKYRL.md
@@ -1,19 +1,19 @@
 # SkyRL Integration (SHM Transport)
 
-This directory contains the `skyrl_adapter.py`, which enables Atropos to act as a high-performance reasoning environment provider for the SkyRL training framework.
+This directory contains `skyrl_adapter.py`, enabling Atropos to provide reasoning environments for the SkyRL training framework.
 
 ## Architecture
 
-The integration utilizes a **Zero-Copy Shared Memory (SHM)** transport to eliminate the "JSON Tax" during reasoning-dense RL collection.
+The integration uses a **Zero-Copy Shared Memory (SHM)** transport to reduce serialization overhead during reasoning-dense RL collection.
 
 * **Transport**: `atroposlib.api.shm_buffer.ZeroCopySHMBuffer`
 * **Adapter**: `atroposlib.envs.skyrl_adapter.SkyRLAdapter`
 
 ## Performance
 
-Benchmarks on RTX 3090 hardware show an **~8x throughput gain** compared to standard HTTP/JSON transport:
+Benchmarks on RTX 3090 hardware:
 
 - **Baseline (HTTP)**: ~2,000 trajectories/sec
-- **Hardened (SHM)**: **16,500+ trajectories/sec**
+- **Hardened (SHM)**: **16,500+ trajectories/sec** (~8x throughput gain)
 
 ## Usage
 
@@ -35,7 +35,7 @@ env = SkyRLAdapter(
-A dedicated end-to-end verification script for the SHM bridge is available in the root directory:
+A dedicated end-to-end verification suite for the SHM bridge is available under `atroposlib/tests/`:
 
 ```bash
-bash test_shm.sh
+pytest -v atroposlib/tests/test_skyrl_shm_e2e.py
 ```
 
-This script verifies the atomic index synchronization and data integrity without requiring a full GPU cluster.
+This suite verifies the atomic index synchronization and data integrity without requiring a full GPU cluster.
diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py
index 0fdb4baa..4cf9df85 100644
--- a/atroposlib/envs/base.py
+++ b/atroposlib/envs/base.py
@@ -867,12 +867,23 @@ class BaseEnv(ABC):
             # Use the provided instance_id (Task ID) if available, fallback to env_id
             inst_id = str(group.get("instance_id") or env_id or "unknown")
             for i in range(len(group["tokens"])):
+                # Collect all group-level metadata; dict.get returns None for missing keys
+                metadata = {
+                    "env": self.name,
+                    "env_id": env_id,
+                    "logprobs": group.get("logprobs"),
+                    "ref_logprobs": group.get("ref_logprobs"),
+                    "distill_token_ids": group.get("distill_token_ids"),
+                    "distill_logprobs": group.get("distill_logprobs"),
+                    "overrides": group.get("overrides"),
+                    "group_overrides": group.get("group_overrides"),
+                }
                 self.shm_buffer.write_trajectory(
                     tokens=group["tokens"][i],
                     score=group["scores"][i] if i < len(group["scores"]) else 0.0,
                     instance_id=inst_id,
                     repetition_id=i,
-                    metadata={"env": self.name, "env_id": env_id},
+                    metadata=metadata,
                 )
             return
@@ -1078,8 +1089,8 @@ class BaseEnv(ABC):
         Optional: Cleanup the environment
         """
         if self.shm_buffer:
-            logger.info("Cleaning up Universal SHM transport: %s", self.config.shm_name)
-            self.shm_buffer.close(unlink=True)
+            logger.info("Closing Universal SHM transport: %s", self.config.shm_name)
+            self.shm_buffer.close(unlink=False)
 
     @retry(
         stop=stop_after_attempt(3), wait=wait_random_exponential(multiplier=1, max=10)
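For the `base.py` hunk above, a minimal runnable sketch of the per-group fan-out it implements. Everything here except the `write_trajectory` keyword signature is invented for illustration: `StubBuffer` stands in for `ZeroCopySHMBuffer`, and `flush_group` condenses the loop from the diff.

```python
from typing import Any, Dict, List


class StubBuffer:
    """Stand-in for ZeroCopySHMBuffer: records calls instead of touching SHM."""

    def __init__(self) -> None:
        self.writes: List[Dict[str, Any]] = []

    def write_trajectory(self, tokens, score, instance_id, repetition_id, metadata):
        self.writes.append(
            {"tokens": tokens, "score": score, "instance_id": instance_id,
             "repetition_id": repetition_id, "metadata": metadata}
        )


def flush_group(buffer: StubBuffer, group: Dict[str, Any], env: str, env_id: str) -> None:
    """Fan one scored group out into one SHM write per repetition."""
    inst_id = str(group.get("instance_id") or env_id or "unknown")
    # dict.get already yields None for absent keys, so no explicit None checks
    metadata = {
        "env": env,
        "env_id": env_id,
        "logprobs": group.get("logprobs"),
        "ref_logprobs": group.get("ref_logprobs"),
    }
    for i, tokens in enumerate(group["tokens"]):
        # Missing scores default to 0.0, mirroring the guard in the diff
        score = group["scores"][i] if i < len(group["scores"]) else 0.0
        buffer.write_trajectory(tokens=tokens, score=score, instance_id=inst_id,
                                repetition_id=i, metadata=metadata)


buf = StubBuffer()
flush_group(buf, {"tokens": [[1, 2], [3]], "scores": [1.0]}, "demo", "env-0")
assert len(buf.writes) == 2 and buf.writes[1]["score"] == 0.0
```

The metadata fields are all group-level, so the sketch builds the dict once outside the loop; the diff rebuilds it per repetition, which is equivalent but slightly more work.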
diff --git a/atroposlib/envs/server_handling/vllm_server.py b/atroposlib/envs/server_handling/vllm_server.py
index 3c35bebb..0d0631dc 100644
--- a/atroposlib/envs/server_handling/vllm_server.py
+++ b/atroposlib/envs/server_handling/vllm_server.py
@@ -58,12 +58,18 @@ class VLLMServer(APIServer):
             ) as response:
                 response.raise_for_status()
                 self.server_healthy = True
+                if getattr(self, "_last_health_count", 0) % 60 == 0:
+                    logger.info(f"❤️ VLLM Server is Healthy at {self.config.base_url}")
+                self._last_health_count = getattr(self, "_last_health_count", 0) + 1
         except (
             aiohttp.ClientError,
             openai.OpenAIError,
             openai.APITimeoutError,
             Exception,
-        ):
+        ) as e:
+            if getattr(self, "_last_error_count", 0) % 60 == 0:
+                logger.warning(f"💔 VLLM Server Health Check Failed at {self.config.base_url}: {e}")
+            self._last_error_count = getattr(self, "_last_error_count", 0) + 1
             self.server_healthy = False
         await asyncio.sleep(1)
 
diff --git a/atroposlib/envs/skyrl_adapter.py b/atroposlib/envs/skyrl_adapter.py
index c3ba66e5..1e678f06 100644
--- a/atroposlib/envs/skyrl_adapter.py
+++ b/atroposlib/envs/skyrl_adapter.py
@@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Union
 
 from pydantic import Field
 
-from ..type_definitions import Message
+from ..type_definitions import Item, Message
 from .base import BaseEnv, BaseEnvConfig, ScoredDataGroup
 
 logger = logging.getLogger(__name__)
@@ -89,15 +89,12 @@
         SkyRL-gym manages its own task queue/dataset internally.
         This provides a dummy item to satisfy the BaseEnv contract.
         """
-        return Item(
-            tokens=[],
-            masks=[],
-            scores=0.0,
-            advantages=None,
-            ref_logprobs=None,
-            messages=None,
-            meta={"source": "skyrl_dummy"},
-        )
+        return {
+            "tokens": [],
+            "masks": [],
+            "scores": 0.0,
+            "meta": {"source": "skyrl_dummy"},
+        }
 
     async def evaluate(self, *args, **kwargs) -> Dict[str, float]:
         """
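The health-check hunk above throttles log output to one line per 60 checks, keeping the counters via `getattr` so no `__init__` change is needed. A self-contained sketch of the same gate; `Checker` is a hypothetical stand-in for `VLLMServer`:

```python
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("health")


class Checker:
    """Emits a log line on checks 0, 60, 120, ... and stays silent otherwise."""

    def check(self, healthy: bool) -> None:
        if healthy:
            if getattr(self, "_last_health_count", 0) % 60 == 0:
                logger.info("server healthy")
            self._last_health_count = getattr(self, "_last_health_count", 0) + 1
        else:
            if getattr(self, "_last_error_count", 0) % 60 == 0:
                logger.warning("health check failed")
            self._last_error_count = getattr(self, "_last_error_count", 0) + 1


c = Checker()
for _ in range(120):
    c.check(True)  # prints "server healthy" exactly twice (checks 0 and 60)
```

With the one-second sleep between checks in the loop above, this works out to roughly one log line per minute per state.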
+ """ @classmethod - def config_init(cls) -> Tuple[SkyRLConfig, List[APIServerConfig]]: - env_config = SkyRLConfig( - tokenizer_name="Qwen/Qwen2.5-1.5B-Instruct", - group_size=8, - use_wandb=True, - rollout_server_url="http://localhost:8000", - total_steps=1000, - batch_size=4, - max_token_length=4096, - wandb_name="skyrl-reasoning", - enable_process_rewards=True, - ) - server_configs = [ - APIServerConfig( - model_name="Qwen/Qwen2.5-1.5B-Instruct", - base_url="http://localhost:9001/v1", - api_key="x", - server_type="sglang", - ), - ] - return env_config, server_configs + def config_init(cls): + from atroposlib.envs.server_handling.server_baseline import ServerBaseline + return SkyRLServerConfig(), ServerBaseline() + + def __init__(self, **kwargs): + super().__init__(**kwargs) + logger.info(f"Initializing SkyRL Server | dataset: {self.config.dataset_path}") + + # Load the dataset + if not os.path.exists(self.config.dataset_path): + logger.error(f"Dataset not found at {self.config.dataset_path}") + # Fallback to a single dummy if file missing (to prevent crash, though it should exist) + self.df = pl.DataFrame({"prompt": ["Please solve 2+2"], "text": ["4"]}) + else: + self.df = pl.read_parquet(self.config.dataset_path) + logger.info(f"Loaded {len(self.df)} prompts from dataset.") + + self.current_idx = 0 + self.lock = asyncio.Lock() + self.status_dict = {} + + async def get_next_item(self) -> Tuple[Any, str]: + """ + Ordered task generation to match the trainer's dataset iteration. + """ + async with self.lock: + if self.current_idx >= len(self.df): + self.current_idx = 0 + logger.info("Dataset loop finished, restarting from index 0.") + + row = self.df.row(self.current_idx, named=True) + prompt = row["prompt"] + uid = str(self.current_idx) + + self.current_idx += 1 + + return prompt, uid + + async def collect_trajectory(self, item_tuple: Tuple[Any, str]) -> Tuple[Dict[str, Any], List[Any]]: + """ + Performs real inference using the Atropos vLLM engine. + Expecting item_tuple to be (prompt, uid) from get_next_item. + """ + item, uid = item_tuple + logger.info(f"Generating trajectory | Task ID: {uid}") + + try: + # Use tokens_and_logprobs_completion to get direct token access + # prompt_tokens, output_tokens, output_logprobs, finish_reasons + ret = await self.server.tokens_and_logprobs_completion( + prompt=item, + max_tokens=self.config.max_token_length, + temperature=0.7, + split="train" + ) + + prompt_tokens, output_tokens, output_logprobs, finish_reasons = ret + + # Since n=1 by default, we take the first completion + tokens = output_tokens[0] + + # Basic Reward Logic: + # In a real scenario, this would call a reward model or a verifier. + # Here we assign 1.0 if any tokens were generated. 
diff --git a/example_trainer/skyrl_bridge_server.py b/example_trainer/skyrl_bridge_server.py
new file mode 100644
index 00000000..8f516e9e
--- /dev/null
+++ b/example_trainer/skyrl_bridge_server.py
@@ -0,0 +1,120 @@
+import os
+import torch
+import uvicorn
+import logging
+from fastapi import FastAPI, Request
+from fastapi.responses import JSONResponse
+from transformers import AutoTokenizer
+from skyrl.backends.skyrl_train.workers.model_wrapper import HFModelWrapper
+from typing import List, Dict, Any
+
+app = FastAPI()
+
+# Logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s [%(levelname)s] SkyRL-Bridge: %(message)s'
+)
+logger = logging.getLogger(__name__)
+
+# Global model and tokenizer
+model = None
+tokenizer = None
+
+@app.on_event("startup")
+async def load_model():
+    global model, tokenizer
+    model_path = os.getenv("MODEL_PATH", "Qwen/Qwen2.5-1.5B-Instruct")
+    logger.info(f"Loading SkyRL-Native Bridge | model: {model_path} | device: cuda:0")
+
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    if tokenizer.pad_token_id is None:
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+
+    model = HFModelWrapper(
+        model_path,
+        use_flash_attention_2=False,  # Stable SDPA for RTX 3090/CUDA 13
+        bf16=True,
+        device_map="cuda:0"
+    )
+    model.eval()
+    logger.info("SkyRL-Native Bridge is ready.")
+
+@app.get("/health")
+async def health():
+    return {"status": "ok"}
+
+@app.post("/generate")
+async def generate(request: Request):
+    data = await request.json()
+
+    # Handle the vLLM prompt format: {"prompt": {"prompt_token_ids": [...]}} OR {"prompt": "..."}
+    prompt_data = data.get("prompt")
+    if isinstance(prompt_data, dict):
+        prompt_token_ids = prompt_data.get("prompt_token_ids")
+        input_ids = torch.tensor([prompt_token_ids]).to("cuda:0")
+    else:
+        # Fall back to a text prompt
+        inputs = tokenizer(prompt_data, return_tensors="pt").to("cuda:0")
+        input_ids = inputs.input_ids
+        prompt_token_ids = input_ids[0].tolist()
+
+    max_new_tokens = data.get("max_tokens", 256)
+    temperature = data.get("temperature", 1.0)
+    top_p = data.get("top_p", 1.0)
+    n = data.get("n", 1)  # Number of completions
+
+    responses = []
+    # vLLM-style logprobs (first token of response)
+    # Atropos expects logprobs: [[{token_id: logprob}, ...]] for each position
+
+    # Simple generation loop for 'n' completions
+    for _ in range(n):
+        with torch.no_grad():
+            output = model.model.generate(
+                input_ids=input_ids,
+                max_new_tokens=max_new_tokens,
+                temperature=temperature,
+                top_p=top_p,
+                do_sample=(temperature > 0),
+                return_dict_in_generate=True,
+                output_scores=True,
+                pad_token_id=tokenizer.pad_token_id,
+                eos_token_id=tokenizer.eos_token_id
+            )
+
+        gen_tokens = output.sequences[0][len(input_ids[0]):].tolist()
+
+        # Calculate logprobs for the generated tokens.
+        # output.scores is a tuple of (max_new_tokens,) tensors of shape (batch, vocab_size)
+        logprobs_list = []
+        for i, score in enumerate(output.scores):
+            # score is (1, vocab_size)
+            probs = torch.log_softmax(score, dim=-1)
+            token_id = gen_tokens[i]
+            token_logprob = probs[0, token_id].item()
+            # Format: [{token_id: logprob}] as expected by vllm_server.py:215
+            logprobs_list.append([{str(token_id): token_logprob}])
+
+        responses.append({
+            "token_ids": gen_tokens,
+            "logprobs": logprobs_list,
+            "finish_reason": "stop" if gen_tokens and gen_tokens[-1] == tokenizer.eos_token_id else "length"
+        })
+
+    # Mimic the vLLM response format.
+    # result["logprobs"] is the logprobs_list for each of the 'n' completions
+    result = {
+        "logprobs": [resp["logprobs"] for resp in responses],
+        "finish_reasons": [resp["finish_reason"] for resp in responses]
+    }
+    return JSONResponse(content=result)
+
+if __name__ == "__main__":
+    import argparse
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--port", type=int, default=9001)
+    parser.add_argument("--host", type=str, default="0.0.0.0")
+    args = parser.parse_args()
+
+    uvicorn.run(app, host=args.host, port=args.port)
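Assuming the response construction above, a fabricated example of the JSON body `/generate` returns for `n=2` completions of two tokens each; every number below is made up for illustration:

```python
import json

result = {
    # One entry per completion; each entry holds one [{token_id: logprob}]
    # list per generated position
    "logprobs": [
        [[{"151645": -0.03}], [{"198": -1.21}]],   # completion 0
        [[{"9707": -0.45}], [{"151645": -0.02}]],  # completion 1
    ],
    "finish_reasons": ["stop", "length"],
}
print(json.dumps(result, indent=2))
```

Note that per-completion `token_ids` are collected in `responses` but not echoed in the top-level result; a caller that needs them must recover the token ids from the logprob dict keys.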
+ """ + self.status_dict = { + "current_step": 0, + "queue_size": 0, # Asynchronous sampling - always ready for more + "max_group_size": self.config.group_size, + "self_queue_size": 0, + "batches_offpolicy": 0, + "max_batches_offpolicy": self.config.max_batches_offpolicy, + } + return self.status_dict if __name__ == "__main__": + # Launch the SkyRLServerEnv via the BaseEnv CLI (serve or process) SkyRLServerEnv.cli() diff --git a/example_trainer/skyrl_bridge_server.py b/example_trainer/skyrl_bridge_server.py new file mode 100644 index 00000000..8f516e9e --- /dev/null +++ b/example_trainer/skyrl_bridge_server.py @@ -0,0 +1,120 @@ +import os +import torch +import uvicorn +import logging +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse +from transformers import AutoTokenizer +from skyrl.backends.skyrl_train.workers.model_wrapper import HFModelWrapper +from typing import List, Dict, Any + +app = FastAPI() + +# Logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s [%(levelname)s] SkyRL-Bridge: %(message)s' +) +logger = logging.getLogger(__name__) + +# Global model and tokenizer +model = None +tokenizer = None + +@app.on_event("startup") +async def load_model(): + global model, tokenizer + model_path = os.getenv("MODEL_PATH", "Qwen/Qwen2.5-1.5B-Instruct") + logger.info(f"Loading SkyRL-Native Bridge | model: {model_path} | device: cuda:0") + + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + if tokenizer.pad_token_id is None: + tokenizer.pad_token_id = tokenizer.eos_token_id + + model = HFModelWrapper( + model_path, + use_flash_attention_2=False, # Stable SDPA for RTX 3090/CUDA 13 + bf16=True, + device_map="cuda:0" + ) + model.eval() + logger.info("SkyRL-Native Bridge is ready.") + +@app.get("/health") +async def health(): + return {"status": "ok"} + +@app.post("/generate") +async def generate(request: Request): + data = await request.json() + + # Handle vLLM prompt format: {"prompt": {"prompt_token_ids": [...]}} OR {"prompt": "..."} + prompt_data = data.get("prompt") + if isinstance(prompt_data, dict): + prompt_token_ids = prompt_data.get("prompt_token_ids") + input_ids = torch.tensor([prompt_token_ids]).to("cuda:0") + else: + # Fallback to text prompt + inputs = tokenizer(prompt_data, return_tensors="pt").to("cuda:0") + input_ids = inputs.input_ids + prompt_token_ids = input_ids[0].tolist() + + max_new_tokens = data.get("max_tokens", 256) + temperature = data.get("temperature", 1.0) + top_p = data.get("top_p", 1.0) + n = data.get("n", 1) # Number of completions + + responses = [] + # vLLM-style logprobs (first token of response) + # Atropos expects logprobs: [[{token_id: logprob}, ...]] for each position + + # Simple generation loop for 'n' completions + for _ in range(n): + with torch.no_grad(): + output = model.model.generate( + input_ids=input_ids, + max_new_tokens=max_new_tokens, + temperature=temperature, + top_p=top_p, + do_sample=(temperature > 0), + return_dict_in_generate=True, + output_scores=True, + pad_token_id=tokenizer.pad_token_id, + eos_token_id=tokenizer.eos_token_id + ) + + gen_tokens = output.sequences[0][len(input_ids[0]):].tolist() + + # Calculate logprobs for generated tokens + # scores is a tuple of (max_new_tokens,) tensors of shape (batch, vocab_size) + logprobs_list = [] + for i, score in enumerate(output.scores): + # score is (1, vocab_size) + probs = torch.log_softmax(score, dim=-1) + token_id = gen_tokens[i] + token_logprob = probs[0, token_id].item() + # Format: [{token_id: logprob}] 
diff --git a/test_shm.sh b/test_shm.sh
deleted file mode 100644
index a3b28985..00000000
--- a/test_shm.sh
+++ /dev/null
@@ -1,44 +0,0 @@
-#!/bin/bash
-
-# Test script for Atropos-SkyRL SHM Transport
-# This script verifies the Zero-Copy SHM bridge and SkyRL adapter logic
-
-set -e  # Exit on error
-
-echo "=========================================="
-echo "Atropos-SkyRL SHM Transport Test"
-echo "=========================================="
-echo ""
-
-# Configuration
-TEST_NAME="shm_test_$(date +%s)"
-echo "Configuration:"
-echo "  - SHM Segment: $TEST_NAME"
-echo ""
-
-# Run the end-to-end SHM verification suite
-echo "Step 1: Running E2E SHM Verification..."
-pytest -v atroposlib/tests/test_skyrl_shm_e2e.py
-if [ $? -eq 0 ]; then
-    echo "✓ SHM E2E verification passed"
-else
-    echo "ERROR: SHM E2E verification failed"
-    exit 1
-fi
-echo ""
-
-# Verify the adapter can be initialized
-echo "Step 2: Verifying SkyRLAdapter Initialization..."
-python3 -c "from atroposlib.envs.skyrl_adapter import SkyRLAdapter; from atroposlib.envs.base import TransportType; print('✓ SkyRLAdapter successfully imported')"
-if [ $? -eq 0 ]; then
-    echo "✓ Adapter initialization verified"
-else
-    echo "ERROR: Failed to initialize SkyRLAdapter"
-    exit 1
-fi
-
-echo ""
-echo "=========================================="
-echo "✓ All Atropos-side SHM tests passed!"
-echo "=========================================="
-echo ""
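With `test_shm.sh` removed, both of its checks survive in pytest form: the E2E suite lives at `atroposlib/tests/test_skyrl_shm_e2e.py`, and the module below is a sketch of how Step 2 of the deleted script maps onto pytest (the import targets are taken verbatim from the removed `python3 -c` line):

```python
import pytest


def test_skyrl_adapter_importable() -> None:
    """Mirrors Step 2 of the deleted test_shm.sh."""
    from atroposlib.envs.skyrl_adapter import SkyRLAdapter  # noqa: F401
    from atroposlib.envs.base import TransportType  # noqa: F401


if __name__ == "__main__":
    # Run this smoke test plus the E2E suite, like the old shell script did
    raise SystemExit(
        pytest.main(["-v", __file__, "atroposlib/tests/test_skyrl_shm_e2e.py"])
    )
```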