diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py
index 4cf9df85..87a550f2 100644
--- a/atroposlib/envs/base.py
+++ b/atroposlib/envs/base.py
@@ -871,12 +871,36 @@ class BaseEnv(ABC):
             metadata = {
                 "env": self.name,
                 "env_id": env_id,
-                "logprobs": group.get("logprobs") if group.get("logprobs") is not None else None,
-                "ref_logprobs": group.get("ref_logprobs") if group.get("ref_logprobs") is not None else None,
-                "distill_token_ids": group.get("distill_token_ids") if group.get("distill_token_ids") is not None else None,
-                "distill_logprobs": group.get("distill_logprobs") if group.get("distill_logprobs") is not None else None,
-                "overrides": group.get("overrides") if group.get("overrides") is not None else None,
-                "group_overrides": group.get("group_overrides") if group.get("group_overrides") is not None else None,
+                "logprobs": (
+                    group.get("logprobs")
+                    if group.get("logprobs") is not None
+                    else None
+                ),
+                "ref_logprobs": (
+                    group.get("ref_logprobs")
+                    if group.get("ref_logprobs") is not None
+                    else None
+                ),
+                "distill_token_ids": (
+                    group.get("distill_token_ids")
+                    if group.get("distill_token_ids") is not None
+                    else None
+                ),
+                "distill_logprobs": (
+                    group.get("distill_logprobs")
+                    if group.get("distill_logprobs") is not None
+                    else None
+                ),
+                "overrides": (
+                    group.get("overrides")
+                    if group.get("overrides") is not None
+                    else None
+                ),
+                "group_overrides": (
+                    group.get("group_overrides")
+                    if group.get("group_overrides") is not None
+                    else None
+                ),
             }
             self.shm_buffer.write_trajectory(
                 tokens=group["tokens"][i],
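The reflowed ternaries above are behaviorally a no-op around `dict.get`, which already returns `None` for absent keys. A follow-up could collapse the whole block instead of reformatting it; a minimal sketch under that assumption (`build_metadata` is a hypothetical helper, not part of this diff):

```python
from typing import Any, Dict

# Keys that are optional in a trajectory group; dict.get() already
# returns None when a key is absent, so no conditional is needed.
OPTIONAL_KEYS = (
    "logprobs",
    "ref_logprobs",
    "distill_token_ids",
    "distill_logprobs",
    "overrides",
    "group_overrides",
)


def build_metadata(env_name: str, env_id: int, group: Dict[str, Any]) -> Dict[str, Any]:
    """Equivalent to the metadata construction in the hunk above."""
    metadata: Dict[str, Any] = {"env": env_name, "env_id": env_id}
    metadata.update({key: group.get(key) for key in OPTIONAL_KEYS})
    return metadata


print(build_metadata("skyrl", 0, {"logprobs": [[-0.5]], "overrides": None}))
```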
diff --git a/atroposlib/envs/server_handling/vllm_server.py b/atroposlib/envs/server_handling/vllm_server.py
index 0d0631dc..f8de8662 100644
--- a/atroposlib/envs/server_handling/vllm_server.py
+++ b/atroposlib/envs/server_handling/vllm_server.py
@@ -68,7 +68,9 @@ class VLLMServer(APIServer):
                 Exception,
             ) as e:
                 if getattr(self, "_last_error_count", 0) % 60 == 0:
-                    logger.warning(f"💔 VLLM Server Health Check Failed at {self.config.base_url}: {e}")
+                    logger.warning(
+                        f"💔 VLLM Server Health Check Failed at {self.config.base_url}: {e}"
+                    )
                 self._last_error_count = getattr(self, "_last_error_count", 0) + 1
                 self.server_healthy = False
                 await asyncio.sleep(1)
diff --git a/environments/skyrl_server.py b/environments/skyrl_server.py
index a66c0ae1..8fb592d3 100644
--- a/environments/skyrl_server.py
+++ b/environments/skyrl_server.py
@@ -1,48 +1,57 @@
+import asyncio
 import logging
 import os
 import sys
-import asyncio
-import polars as pl
 from typing import Any, Dict, List, Optional, Tuple
 
+import polars as pl
+
 # Add atropos to path if not already there
 sys.path.append("/root/atropos")
 
-from atroposlib.envs.base import BaseEnv, BaseEnvConfig
 from pydantic import Field
 
+from atroposlib.envs.base import BaseEnv, BaseEnvConfig
+
 # Logging
 logging.basicConfig(
     level=logging.INFO,
-    format='%(asctime)s [%(levelname)s] %(name)s:%(lineno)d - %(message)s'
+    format="%(asctime)s [%(levelname)s] %(name)s:%(lineno)d - %(message)s",
 )
 logger = logging.getLogger(__name__)
 
+
 class SkyRLServerConfig(BaseEnvConfig):
     """
     Configuration for the SkyRL Production Server.
     """
+
     dataset_path: str = Field(
         default="/root/SkyRL/tests/dummy_fixed_16.parquet",
-        description="Path to the parquet dataset for task generation."
+        description="Path to the parquet dataset for task generation.",
     )
     shm_name: str = Field(default="atropos_shm", description="Name of the SHM segment")
-    shm_size: int = Field(default=1000, description="Size of the SHM segment in entries")
+    shm_size: int = Field(
+        default=1000, description="Size of the SHM segment in entries"
+    )
+
 
 class SkyRLServerEnv(BaseEnv):
     """
     Production-ready Atropos Environment for SkyRL.
     Pulls real tasks from a dataset and performs real vLLM inference.
     """
+
     @classmethod
     def config_init(cls):
         from atroposlib.envs.server_handling.server_baseline import ServerBaseline
+
         return SkyRLServerConfig(), ServerBaseline()
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         logger.info(f"Initializing SkyRL Server | dataset: {self.config.dataset_path}")
-        
+
         # Load the dataset
         if not os.path.exists(self.config.dataset_path):
             logger.error(f"Dataset not found at {self.config.dataset_path}")
@@ -51,7 +60,7 @@ class SkyRLServerEnv(BaseEnv):
         else:
             self.df = pl.read_parquet(self.config.dataset_path)
             logger.info(f"Loaded {len(self.df)} prompts from dataset.")
-        
+
         self.current_idx = 0
         self.lock = asyncio.Lock()
         self.status_dict = {}
@@ -64,23 +73,25 @@ class SkyRLServerEnv(BaseEnv):
             if self.current_idx >= len(self.df):
                 self.current_idx = 0
                 logger.info("Dataset loop finished, restarting from index 0.")
-            
+
             row = self.df.row(self.current_idx, named=True)
             prompt = row["prompt"]
             uid = str(self.current_idx)
-            
+
             self.current_idx += 1
-            
+
             return prompt, uid
 
-    async def collect_trajectory(self, item_tuple: Tuple[Any, str]) -> Tuple[Dict[str, Any], List[Any]]:
+    async def collect_trajectory(
+        self, item_tuple: Tuple[Any, str]
+    ) -> Tuple[Dict[str, Any], List[Any]]:
         """
         Performs real inference using the Atropos vLLM engine.
         Expecting item_tuple to be (prompt, uid) from get_next_item.
         """
        item, uid = item_tuple
         logger.info(f"Generating trajectory | Task ID: {uid}")
-        
+
         try:
             # Use tokens_and_logprobs_completion to get direct token access
             # prompt_tokens, output_tokens, output_logprobs, finish_reasons
@@ -88,21 +99,23 @@
             ret = await self.server.tokens_and_logprobs_completion(
                 prompt=item,
                 max_tokens=self.config.max_token_length,
                 temperature=0.7,
-                split="train"
+                split="train",
             )
-            
+
             prompt_tokens, output_tokens, output_logprobs, finish_reasons = ret
-            
+
             # Since n=1 by default, we take the first completion
             tokens = output_tokens[0]
-            
+
             # Basic Reward Logic:
             # In a real scenario, this would call a reward model or a verifier.
             # Here we assign 1.0 if any tokens were generated.
             score = 1.0 if len(tokens) > 2 else 0.0
-            
-            logger.info(f"Task {uid} completed | tokens: {len(tokens)} | score: {score}")
-            
+
+            logger.info(
+                f"Task {uid} completed | tokens: {len(tokens)} | score: {score}"
+            )
+
             # Return (dict, backlog) tuple as expected by BaseEnv
             return {
                 "instance_id": uid,
@@ -117,6 +130,7 @@
         except Exception as e:
             logger.error(f"Inference error | Task {uid}: {e}")
             import traceback
+
             traceback.print_exc()
             # Return empty to allow the loop to continue
             return {
@@ -141,7 +155,7 @@
         No-op for SkyRL joint training to avoid connection errors.
         """
         logger.info("WandB setup bypassed.")
-        
+
     async def get_server_info(self):
         """
         No-op for SkyRL joint training.
@@ -170,7 +184,7 @@ class SkyRLServerEnv(BaseEnv):
         """
         self.status_dict = {
             "current_step": 0,
-            "queue_size": 0, # Asynchronous sampling - always ready for more
+            "queue_size": 0,  # Asynchronous sampling - always ready for more
             "max_group_size": self.config.group_size,
             "self_queue_size": 0,
             "batches_offpolicy": 0,
@@ -178,6 +192,7 @@
         }
         return self.status_dict
 
+
 if __name__ == "__main__":
     # Launch the SkyRLServerEnv via the BaseEnv CLI (serve or process)
     SkyRLServerEnv.cli()
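For context on the `get_next_item` reflow above: the semantics are a lock-guarded cursor over the parquet rows that wraps back to index 0 once the dataset is exhausted. A self-contained sketch of that pattern (class and names are illustrative, not part of the diff):

```python
import asyncio
from typing import List, Tuple


class RoundRobinSampler:
    """Lock-guarded cursor over a finite prompt list that wraps around,
    mirroring the get_next_item logic in skyrl_server.py."""

    def __init__(self, prompts: List[str]) -> None:
        self.prompts = prompts
        self.current_idx = 0
        self.lock = asyncio.Lock()

    async def next(self) -> Tuple[str, str]:
        async with self.lock:
            if self.current_idx >= len(self.prompts):
                self.current_idx = 0  # dataset exhausted: restart from index 0
            prompt = self.prompts[self.current_idx]
            uid = str(self.current_idx)
            self.current_idx += 1
            return prompt, uid


async def main() -> None:
    sampler = RoundRobinSampler(["a", "b"])
    for _ in range(3):
        print(await sampler.next())  # ('a', '0'), ('b', '1'), ('a', '0')


if __name__ == "__main__":
    asyncio.run(main())
```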
diff --git a/example_trainer/skyrl_bridge_server.py b/example_trainer/skyrl_bridge_server.py
index 8f516e9e..2d7647db 100644
--- a/example_trainer/skyrl_bridge_server.py
+++ b/example_trainer/skyrl_bridge_server.py
@@ -1,19 +1,19 @@
+import logging
 import os
+from typing import Any, Dict, List
+
 import torch
 import uvicorn
-import logging
 from fastapi import FastAPI, Request
 from fastapi.responses import JSONResponse
-from transformers import AutoTokenizer
 from skyrl.backends.skyrl_train.workers.model_wrapper import HFModelWrapper
-from typing import List, Dict, Any
+from transformers import AutoTokenizer
 
 app = FastAPI()
 
 # Logging
 logging.basicConfig(
-    level=logging.INFO,
-    format='%(asctime)s [%(levelname)s] SkyRL-Bridge: %(message)s'
+    level=logging.INFO, format="%(asctime)s [%(levelname)s] SkyRL-Bridge: %(message)s"
 )
 logger = logging.getLogger(__name__)
 
@@ -21,33 +21,36 @@ logger = logging.getLogger(__name__)
 model = None
 tokenizer = None
 
+
 @app.on_event("startup")
 async def load_model():
     global model, tokenizer
     model_path = os.getenv("MODEL_PATH", "Qwen/Qwen2.5-1.5B-Instruct")
     logger.info(f"Loading SkyRL-Native Bridge | model: {model_path} | device: cuda:0")
-    
+
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
     if tokenizer.pad_token_id is None:
         tokenizer.pad_token_id = tokenizer.eos_token_id
-    
+
     model = HFModelWrapper(
         model_path,
-        use_flash_attention_2=False, # Stable SDPA for RTX 3090/CUDA 13
+        use_flash_attention_2=False,  # Stable SDPA for RTX 3090/CUDA 13
         bf16=True,
-        device_map="cuda:0"
+        device_map="cuda:0",
     )
     model.eval()
     logger.info("SkyRL-Native Bridge is ready.")
 
+
 @app.get("/health")
 async def health():
     return {"status": "ok"}
 
+
 @app.post("/generate")
 async def generate(request: Request):
     data = await request.json()
-    
+
     # Handle vLLM prompt format: {"prompt": {"prompt_token_ids": [...]}} OR {"prompt": "..."}
     prompt_data = data.get("prompt")
     if isinstance(prompt_data, dict):
@@ -62,12 +65,12 @@ async def generate(request: Request):
     max_new_tokens = data.get("max_tokens", 256)
     temperature = data.get("temperature", 1.0)
     top_p = data.get("top_p", 1.0)
-    n = data.get("n", 1) # Number of completions
-    
+    n = data.get("n", 1)  # Number of completions
+
     responses = []
     # vLLM-style logprobs (first token of response)
     # Atropos expects logprobs: [[{token_id: logprob}, ...]] for each position
-    
+
     # Simple generation loop for 'n' completions
     for _ in range(n):
         with torch.no_grad():
@@ -80,11 +83,11 @@
                 return_dict_in_generate=True,
                 output_scores=True,
                 pad_token_id=tokenizer.pad_token_id,
-                eos_token_id=tokenizer.eos_token_id
+                eos_token_id=tokenizer.eos_token_id,
             )
-
-        gen_tokens = output.sequences[0][len(input_ids[0]):].tolist()
-
+
+        gen_tokens = output.sequences[0][len(input_ids[0]) :].tolist()
+
         # Calculate logprobs for generated tokens
         # scores is a tuple of (max_new_tokens,) tensors of shape (batch, vocab_size)
         logprobs_list = []
@@ -95,26 +98,32 @@
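The logprob extraction in `/generate` above pairs each generated token with the matching step of `output.scores` and reads its log-probability off a `log_softmax`, emitting the vLLM-style `[[{token_id: logprob}], ...]` layout. A standalone sketch of that step with synthetic tensors (the helper name is illustrative, not part of the diff):

```python
from typing import Dict, List

import torch


def per_token_logprobs(
    gen_tokens: List[int], scores: List[torch.Tensor]
) -> List[List[Dict[str, float]]]:
    """Mirror the loop in /generate: one log_softmax per generation step,
    then pick the logprob of the token that was actually sampled."""
    logprobs_list = []
    for token_id, step_scores in zip(gen_tokens, scores):
        probs = torch.log_softmax(step_scores, dim=-1)
        logprobs_list.append([{str(token_id): probs[0, token_id].item()}])
    return logprobs_list


# Tiny synthetic check: 2 generated tokens over a 5-token vocabulary.
demo = per_token_logprobs([3, 1], [torch.randn(1, 5), torch.randn(1, 5)])
print(demo)  # e.g. [[{'3': -1.42}], [{'1': -2.07}]]
```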
             token_logprob = probs[0, token_id].item()
             # Format: [{token_id: logprob}] as expected by vllm_server.py:215
             logprobs_list.append([{str(token_id): token_logprob}])
-        
-        responses.append({
-            "token_ids": gen_tokens,
-            "logprobs": logprobs_list,
-            "finish_reason": "stop" if gen_tokens[-1] == tokenizer.eos_token_id else "length"
-        })
+
+        responses.append(
+            {
+                "token_ids": gen_tokens,
+                "logprobs": logprobs_list,
+                "finish_reason": (
+                    "stop" if gen_tokens[-1] == tokenizer.eos_token_id else "length"
+                ),
+            }
+        )
 
     # Mimic vLLM response format
     # results["logprobs"] is a list of logprobs_list for each 'n' completion
     result = {
         "logprobs": [resp["logprobs"] for resp in responses],
-        "finish_reasons": [resp["finish_reason"] for resp in responses]
+        "finish_reasons": [resp["finish_reason"] for resp in responses],
     }
     return JSONResponse(content=result)
 
+
 if __name__ == "__main__":
     import argparse
+
     parser = argparse.ArgumentParser()
     parser.add_argument("--port", type=int, default=9001)
     parser.add_argument("--host", type=str, default="0.0.0.0")
     args = parser.parse_args()
-    
+
     uvicorn.run(app, host=args.host, port=args.port)
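The import shim above can be exercised without starting an engine; a rough smoke-test sketch, assuming vLLM is installed and using the module paths taken from the hunk (the helper name is illustrative, not part of the diff):

```python
import importlib
import os


def resolve_async_llm():
    """Resolve the engine class the server would use, mirroring the
    VLLM_USE_V1 / v0-fallback logic from vllm_api_server.py."""
    if os.environ.get("VLLM_USE_V1", "0") == "1":
        return importlib.import_module("vllm.v1.engine.async_llm").AsyncLLM
    try:
        return importlib.import_module("vllm.engine.async_llm_engine").AsyncLLMEngine
    except ImportError:
        # Older v0 layouts keep the class in vllm.engine.async_llm.
        return importlib.import_module("vllm.engine.async_llm").AsyncLLM


if __name__ == "__main__":
    cls = resolve_async_llm()
    print(f"Selected engine class: {cls.__module__}.{cls.__name__}")
```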