[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
pre-commit-ci[bot] 2026-04-07 22:25:03 +00:00
parent 463aa79ae8
commit 7c67e0bb19
5 changed files with 108 additions and 55 deletions


@@ -871,12 +871,36 @@ class BaseEnv(ABC):
metadata = {
"env": self.name,
"env_id": env_id,
"logprobs": group.get("logprobs") if group.get("logprobs") is not None else None,
"ref_logprobs": group.get("ref_logprobs") if group.get("ref_logprobs") is not None else None,
"distill_token_ids": group.get("distill_token_ids") if group.get("distill_token_ids") is not None else None,
"distill_logprobs": group.get("distill_logprobs") if group.get("distill_logprobs") is not None else None,
"overrides": group.get("overrides") if group.get("overrides") is not None else None,
"group_overrides": group.get("group_overrides") if group.get("group_overrides") is not None else None,
"logprobs": (
group.get("logprobs")
if group.get("logprobs") is not None
else None
),
"ref_logprobs": (
group.get("ref_logprobs")
if group.get("ref_logprobs") is not None
else None
),
"distill_token_ids": (
group.get("distill_token_ids")
if group.get("distill_token_ids") is not None
else None
),
"distill_logprobs": (
group.get("distill_logprobs")
if group.get("distill_logprobs") is not None
else None
),
"overrides": (
group.get("overrides")
if group.get("overrides") is not None
else None
),
"group_overrides": (
group.get("group_overrides")
if group.get("group_overrides") is not None
else None
),
}
self.shm_buffer.write_trajectory(
tokens=group["tokens"][i],
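
An aside on the expression black is wrapping in this hunk: for any value v, "v if v is not None else None" evaluates to v, so each of these metadata entries is equivalent to the bare group.get(...) call. A standalone sketch, not part of the diff, illustrating the equivalence:

# Illustration only; "group" here is a stand-in dict, not the trajectory group above.
group = {"logprobs": [-0.1, -0.2], "overrides": None}

wrapped = group.get("logprobs") if group.get("logprobs") is not None else None
plain = group.get("logprobs")
assert wrapped == plain  # both are [-0.1, -0.2]

# The equivalence also holds for missing or None-valued keys: both sides are None.
assert (group.get("ref_logprobs") if group.get("ref_logprobs") is not None else None) is None
assert group.get("ref_logprobs") is None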


@@ -68,7 +68,9 @@ class VLLMServer(APIServer):
Exception,
) as e:
if getattr(self, "_last_error_count", 0) % 60 == 0:
logger.warning(f"💔 VLLM Server Health Check Failed at {self.config.base_url}: {e}")
logger.warning(
f"💔 VLLM Server Health Check Failed at {self.config.base_url}: {e}"
)
self._last_error_count = getattr(self, "_last_error_count", 0) + 1
self.server_healthy = False
await asyncio.sleep(1)
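
The modulo check above rate-limits the warning so a persistently unreachable vLLM server is reported roughly once per sixty failed probes rather than on every one-second retry. A minimal, hypothetical sketch of the same throttling pattern in isolation (names are illustrative, not from this file):

import asyncio
import logging

logger = logging.getLogger("health")

async def watch(check, interval: float = 1.0, every: int = 60) -> None:
    # Hypothetical helper: probe in a loop and emit a warning only when the
    # consecutive-failure count is a multiple of "every".
    failures = 0
    while True:
        try:
            await check()
            failures = 0
        except Exception as exc:
            if failures % every == 0:
                logger.warning("health check failed: %s", exc)
            failures += 1
        await asyncio.sleep(interval)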


@@ -1,48 +1,57 @@
import asyncio
import logging
import os
import sys
import asyncio
import polars as pl
from typing import Any, Dict, List, Optional, Tuple
import polars as pl
# Add atropos to path if not already there
sys.path.append("/root/atropos")
from atroposlib.envs.base import BaseEnv, BaseEnvConfig
from pydantic import Field
from atroposlib.envs.base import BaseEnv, BaseEnvConfig
# Logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s [%(levelname)s] %(name)s:%(lineno)d - %(message)s'
format="%(asctime)s [%(levelname)s] %(name)s:%(lineno)d - %(message)s",
)
logger = logging.getLogger(__name__)
class SkyRLServerConfig(BaseEnvConfig):
"""
Configuration for the SkyRL Production Server.
"""
dataset_path: str = Field(
default="/root/SkyRL/tests/dummy_fixed_16.parquet",
description="Path to the parquet dataset for task generation."
description="Path to the parquet dataset for task generation.",
)
shm_name: str = Field(default="atropos_shm", description="Name of the SHM segment")
shm_size: int = Field(default=1000, description="Size of the SHM segment in entries")
shm_size: int = Field(
default=1000, description="Size of the SHM segment in entries"
)
class SkyRLServerEnv(BaseEnv):
"""
Production-ready Atropos Environment for SkyRL.
Pulls real tasks from a dataset and performs real vLLM inference.
"""
@classmethod
def config_init(cls):
from atroposlib.envs.server_handling.server_baseline import ServerBaseline
return SkyRLServerConfig(), ServerBaseline()
def __init__(self, **kwargs):
super().__init__(**kwargs)
logger.info(f"Initializing SkyRL Server | dataset: {self.config.dataset_path}")
# Load the dataset
if not os.path.exists(self.config.dataset_path):
logger.error(f"Dataset not found at {self.config.dataset_path}")
@@ -51,7 +60,7 @@ class SkyRLServerEnv(BaseEnv):
else:
self.df = pl.read_parquet(self.config.dataset_path)
logger.info(f"Loaded {len(self.df)} prompts from dataset.")
self.current_idx = 0
self.lock = asyncio.Lock()
self.status_dict = {}
@@ -64,23 +73,25 @@ class SkyRLServerEnv(BaseEnv):
if self.current_idx >= len(self.df):
self.current_idx = 0
logger.info("Dataset loop finished, restarting from index 0.")
row = self.df.row(self.current_idx, named=True)
prompt = row["prompt"]
uid = str(self.current_idx)
self.current_idx += 1
return prompt, uid
async def collect_trajectory(self, item_tuple: Tuple[Any, str]) -> Tuple[Dict[str, Any], List[Any]]:
async def collect_trajectory(
self, item_tuple: Tuple[Any, str]
) -> Tuple[Dict[str, Any], List[Any]]:
"""
Performs real inference using the Atropos vLLM engine.
Expecting item_tuple to be (prompt, uid) from get_next_item.
"""
item, uid = item_tuple
logger.info(f"Generating trajectory | Task ID: {uid}")
try:
# Use tokens_and_logprobs_completion to get direct token access
# prompt_tokens, output_tokens, output_logprobs, finish_reasons
@@ -88,21 +99,23 @@ class SkyRLServerEnv(BaseEnv):
prompt=item,
max_tokens=self.config.max_token_length,
temperature=0.7,
split="train"
split="train",
)
prompt_tokens, output_tokens, output_logprobs, finish_reasons = ret
# Since n=1 by default, we take the first completion
tokens = output_tokens[0]
# Basic Reward Logic:
# In a real scenario, this would call a reward model or a verifier.
# Here we assign 1.0 if any tokens were generated.
score = 1.0 if len(tokens) > 2 else 0.0
logger.info(f"Task {uid} completed | tokens: {len(tokens)} | score: {score}")
logger.info(
f"Task {uid} completed | tokens: {len(tokens)} | score: {score}"
)
# Return (dict, backlog) tuple as expected by BaseEnv
return {
"instance_id": uid,
@@ -117,6 +130,7 @@ class SkyRLServerEnv(BaseEnv):
except Exception as e:
logger.error(f"Inference error | Task {uid}: {e}")
import traceback
traceback.print_exc()
# Return empty to allow the loop to continue
return {
@@ -141,7 +155,7 @@ class SkyRLServerEnv(BaseEnv):
No-op for SkyRL joint training to avoid connection errors.
"""
logger.info("WandB setup bypassed.")
async def get_server_info(self):
"""
No-op for SkyRL joint training.
@@ -170,7 +184,7 @@ class SkyRLServerEnv(BaseEnv):
"""
self.status_dict = {
"current_step": 0,
"queue_size": 0, # Asynchronous sampling - always ready for more
"queue_size": 0, # Asynchronous sampling - always ready for more
"max_group_size": self.config.group_size,
"self_queue_size": 0,
"batches_offpolicy": 0,
@@ -178,6 +192,7 @@ class SkyRLServerEnv(BaseEnv):
}
return self.status_dict
if __name__ == "__main__":
# Launch the SkyRLServerEnv via the BaseEnv CLI (serve or process)
SkyRLServerEnv.cli()
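
For reference, the dataset access pattern used by get_next_item, reading a parquet file with polars and wrapping around at the end, can be exercised on its own. A minimal sketch; the path and the "prompt" column name are taken from the code above, everything else is illustrative:

import polars as pl

df = pl.read_parquet("/root/SkyRL/tests/dummy_fixed_16.parquet")  # the config default path
idx = 0
for _ in range(len(df) + 2):  # step a little past the end to show the wrap-around
    if idx >= len(df):
        idx = 0  # restart from index 0, as get_next_item does
    row = df.row(idx, named=True)  # named=True returns a dict keyed by column name
    prompt, uid = row["prompt"], str(idx)
    idx += 1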


@@ -1,19 +1,19 @@
import logging
import os
from typing import Any, Dict, List
import torch
import uvicorn
import logging
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from transformers import AutoTokenizer
from skyrl.backends.skyrl_train.workers.model_wrapper import HFModelWrapper
from typing import List, Dict, Any
from transformers import AutoTokenizer
app = FastAPI()
# Logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s [%(levelname)s] SkyRL-Bridge: %(message)s'
level=logging.INFO, format="%(asctime)s [%(levelname)s] SkyRL-Bridge: %(message)s"
)
logger = logging.getLogger(__name__)
@@ -21,33 +21,36 @@ logger = logging.getLogger(__name__)
model = None
tokenizer = None
@app.on_event("startup")
async def load_model():
global model, tokenizer
model_path = os.getenv("MODEL_PATH", "Qwen/Qwen2.5-1.5B-Instruct")
logger.info(f"Loading SkyRL-Native Bridge | model: {model_path} | device: cuda:0")
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
model = HFModelWrapper(
model_path,
use_flash_attention_2=False, # Stable SDPA for RTX 3090/CUDA 13
use_flash_attention_2=False, # Stable SDPA for RTX 3090/CUDA 13
bf16=True,
device_map="cuda:0"
device_map="cuda:0",
)
model.eval()
logger.info("SkyRL-Native Bridge is ready.")
@app.get("/health")
async def health():
return {"status": "ok"}
@app.post("/generate")
async def generate(request: Request):
data = await request.json()
# Handle vLLM prompt format: {"prompt": {"prompt_token_ids": [...]}} OR {"prompt": "..."}
prompt_data = data.get("prompt")
if isinstance(prompt_data, dict):
@@ -62,12 +65,12 @@ async def generate(request: Request):
max_new_tokens = data.get("max_tokens", 256)
temperature = data.get("temperature", 1.0)
top_p = data.get("top_p", 1.0)
n = data.get("n", 1) # Number of completions
n = data.get("n", 1) # Number of completions
responses = []
# vLLM-style logprobs (first token of response)
# Atropos expects logprobs: [[{token_id: logprob}, ...]] for each position
# Simple generation loop for 'n' completions
for _ in range(n):
with torch.no_grad():
@@ -80,11 +83,11 @@ async def generate(request: Request):
return_dict_in_generate=True,
output_scores=True,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id
eos_token_id=tokenizer.eos_token_id,
)
gen_tokens = output.sequences[0][len(input_ids[0]):].tolist()
gen_tokens = output.sequences[0][len(input_ids[0]) :].tolist()
# Calculate logprobs for generated tokens
# scores is a tuple of (max_new_tokens,) tensors of shape (batch, vocab_size)
logprobs_list = []
@@ -95,26 +98,32 @@ async def generate(request: Request):
token_logprob = probs[0, token_id].item()
# Format: [{token_id: logprob}] as expected by vllm_server.py:215
logprobs_list.append([{str(token_id): token_logprob}])
responses.append({
"token_ids": gen_tokens,
"logprobs": logprobs_list,
"finish_reason": "stop" if gen_tokens[-1] == tokenizer.eos_token_id else "length"
})
responses.append(
{
"token_ids": gen_tokens,
"logprobs": logprobs_list,
"finish_reason": (
"stop" if gen_tokens[-1] == tokenizer.eos_token_id else "length"
),
}
)
# Mimic vLLM response format
# results["logprobs"] is a list of logprobs_list for each 'n' completion
result = {
"logprobs": [resp["logprobs"] for resp in responses],
"finish_reasons": [resp["finish_reason"] for resp in responses]
"finish_reasons": [resp["finish_reason"] for resp in responses],
}
return JSONResponse(content=result)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=9001)
parser.add_argument("--host", type=str, default="0.0.0.0")
args = parser.parse_args()
uvicorn.run(app, host=args.host, port=args.port)
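
Putting the handler together, a hypothetical client call to the bridge could look like the sketch below. The payload fields and the two accepted prompt shapes mirror what the handler reads; the port is the argparse default; the client library and the token ids are illustrative only:

import requests  # any HTTP client works; requests is used only for illustration

payload = {
    "prompt": "Hello from the bridge",  # or pre-tokenized: {"prompt_token_ids": [151644, 872]}
    "max_tokens": 64,
    "temperature": 0.7,
    "top_p": 1.0,
    "n": 1,
}
resp = requests.post("http://localhost:9001/generate", json=payload, timeout=300)
result = resp.json()
# result["logprobs"]       -> one entry per completion; each entry holds one
#                             [{token_id: logprob}] item per generated position
# result["finish_reasons"] -> "stop" or "length" per completion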


@@ -132,12 +132,15 @@ from vllm.logger import init_logger # noqa: E402
from vllm.sampling_params import RequestOutputKind, SamplingParams # noqa: E402
from vllm.usage.usage_lib import UsageContext # noqa: E402
from vllm.utils import random_uuid # noqa: E402
# Handle vLLM engine version differences (v0 vs v1)
if os.environ.get("VLLM_USE_V1", "0") == "1":
from vllm.v1.engine.async_llm import AsyncLLM # noqa: E402
else:
try:
from vllm.engine.async_llm_engine import AsyncLLMEngine as AsyncLLM # noqa: E402
from vllm.engine.async_llm_engine import (
AsyncLLMEngine as AsyncLLM,
) # noqa: E402
except ImportError:
# Fallback for older v0 versions
from vllm.engine.async_llm import AsyncLLM # noqa: E402
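
Whichever branch supplies AsyncLLM, downstream code can treat the alias uniformly. Below is a minimal sketch of driving it through the standard vLLM v0 API (AsyncEngineArgs, from_engine_args, streamed generate); this is an assumption about typical usage, not code from this file, and the model name is simply reused from the bridge script above:

import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine as AsyncLLM  # v0 alias, as above
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid

async def main() -> None:
    engine = AsyncLLM.from_engine_args(AsyncEngineArgs(model="Qwen/Qwen2.5-1.5B-Instruct"))
    params = SamplingParams(max_tokens=64, temperature=0.7)
    final = None
    async for out in engine.generate("Hello", params, request_id=random_uuid()):
        final = out  # the engine streams RequestOutput updates until completion
    print(final.outputs[0].text)

asyncio.run(main())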