Mirror of https://github.com/NousResearch/atropos.git (synced 2026-04-19 12:57:58 +00:00)
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
parent 463aa79ae8
commit 7c67e0bb19
5 changed files with 108 additions and 55 deletions
@@ -871,12 +871,36 @@ class BaseEnv(ABC):
            metadata = {
                "env": self.name,
                "env_id": env_id,
                "logprobs": group.get("logprobs") if group.get("logprobs") is not None else None,
                "ref_logprobs": group.get("ref_logprobs") if group.get("ref_logprobs") is not None else None,
                "distill_token_ids": group.get("distill_token_ids") if group.get("distill_token_ids") is not None else None,
                "distill_logprobs": group.get("distill_logprobs") if group.get("distill_logprobs") is not None else None,
                "overrides": group.get("overrides") if group.get("overrides") is not None else None,
                "group_overrides": group.get("group_overrides") if group.get("group_overrides") is not None else None,
                "logprobs": (
                    group.get("logprobs")
                    if group.get("logprobs") is not None
                    else None
                ),
                "ref_logprobs": (
                    group.get("ref_logprobs")
                    if group.get("ref_logprobs") is not None
                    else None
                ),
                "distill_token_ids": (
                    group.get("distill_token_ids")
                    if group.get("distill_token_ids") is not None
                    else None
                ),
                "distill_logprobs": (
                    group.get("distill_logprobs")
                    if group.get("distill_logprobs") is not None
                    else None
                ),
                "overrides": (
                    group.get("overrides")
                    if group.get("overrides") is not None
                    else None
                ),
                "group_overrides": (
                    group.get("group_overrides")
                    if group.get("group_overrides") is not None
                    else None
                ),
            }
            self.shm_buffer.write_trajectory(
                tokens=group["tokens"][i],

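Note: the reflow above is purely cosmetic. `group.get(key) if group.get(key) is not None else None` always evaluates to `group.get(key)`, since dict.get already returns None for missing keys. A minimal sketch of that equivalence, using a made-up group payload rather than the real trajectory data:

# Hypothetical payload; the real group carries token and logprob arrays per rollout.
group = {"logprobs": [-0.12, -2.31], "overrides": None}

for key in ("logprobs", "ref_logprobs", "overrides"):
    verbose = group.get(key) if group.get(key) is not None else None
    assert verbose == group.get(key)  # identical result, with or without the ternary
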
@@ -68,7 +68,9 @@ class VLLMServer(APIServer):
                Exception,
            ) as e:
                if getattr(self, "_last_error_count", 0) % 60 == 0:
                    logger.warning(f"💔 VLLM Server Health Check Failed at {self.config.base_url}: {e}")
                    logger.warning(
                        f"💔 VLLM Server Health Check Failed at {self.config.base_url}: {e}"
                    )
                self._last_error_count = getattr(self, "_last_error_count", 0) + 1
                self.server_healthy = False
                await asyncio.sleep(1)

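The counter above throttles log spam: the loop still marks the server unhealthy and retries every second, but only emits a warning on roughly every 60th consecutive failure. A self-contained sketch of this throttling idea, with a stubbed probe standing in for the real HTTP health check:

import asyncio
import logging

logger = logging.getLogger(__name__)


async def watch_health(probe) -> None:
    # probe is any async callable that raises on failure (stand-in for the real check).
    error_count = 0
    while True:
        try:
            await probe()
            error_count = 0  # reset once the server responds again
        except Exception as e:
            if error_count % 60 == 0:  # about one warning per minute of failures
                logger.warning(f"Health check failed: {e}")
            error_count += 1
        await asyncio.sleep(1)
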
@@ -1,42 +1,51 @@
import asyncio
import logging
import os
import sys
import asyncio
import polars as pl
from typing import Any, Dict, List, Optional, Tuple

import polars as pl

# Add atropos to path if not already there
sys.path.append("/root/atropos")

from atroposlib.envs.base import BaseEnv, BaseEnvConfig
from pydantic import Field

from atroposlib.envs.base import BaseEnv, BaseEnvConfig

# Logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(name)s:%(lineno)d - %(message)s'
    format="%(asctime)s [%(levelname)s] %(name)s:%(lineno)d - %(message)s",
)
logger = logging.getLogger(__name__)


class SkyRLServerConfig(BaseEnvConfig):
    """
    Configuration for the SkyRL Production Server.
    """

    dataset_path: str = Field(
        default="/root/SkyRL/tests/dummy_fixed_16.parquet",
        description="Path to the parquet dataset for task generation."
        description="Path to the parquet dataset for task generation.",
    )
    shm_name: str = Field(default="atropos_shm", description="Name of the SHM segment")
    shm_size: int = Field(default=1000, description="Size of the SHM segment in entries")
    shm_size: int = Field(
        default=1000, description="Size of the SHM segment in entries"
    )


class SkyRLServerEnv(BaseEnv):
    """
    Production-ready Atropos Environment for SkyRL.
    Pulls real tasks from a dataset and performs real vLLM inference.
    """

    @classmethod
    def config_init(cls):
        from atroposlib.envs.server_handling.server_baseline import ServerBaseline

        return SkyRLServerConfig(), ServerBaseline()

    def __init__(self, **kwargs):

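config_init above just returns the defaults; since SkyRLServerConfig is a pydantic model, individual fields can be overridden at construction time. A small sketch, assuming the inherited BaseEnvConfig fields all carry defaults (the dataset path here is illustrative, not the bundled fixture):

config = SkyRLServerConfig(
    dataset_path="/data/my_tasks.parquet",  # hypothetical path
    shm_name="atropos_shm",
    shm_size=2000,
)
print(config.shm_name, config.shm_size)
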
@@ -73,7 +82,9 @@ class SkyRLServerEnv(BaseEnv):

        return prompt, uid

    async def collect_trajectory(self, item_tuple: Tuple[Any, str]) -> Tuple[Dict[str, Any], List[Any]]:
    async def collect_trajectory(
        self, item_tuple: Tuple[Any, str]
    ) -> Tuple[Dict[str, Any], List[Any]]:
        """
        Performs real inference using the Atropos vLLM engine.
        Expecting item_tuple to be (prompt, uid) from get_next_item.

@@ -88,7 +99,7 @@ class SkyRLServerEnv(BaseEnv):
                prompt=item,
                max_tokens=self.config.max_token_length,
                temperature=0.7,
                split="train"
                split="train",
            )

            prompt_tokens, output_tokens, output_logprobs, finish_reasons = ret

@@ -101,7 +112,9 @@ class SkyRLServerEnv(BaseEnv):
            # Here we assign 1.0 if any tokens were generated.
            score = 1.0 if len(tokens) > 2 else 0.0

            logger.info(f"Task {uid} completed | tokens: {len(tokens)} | score: {score}")
            logger.info(
                f"Task {uid} completed | tokens: {len(tokens)} | score: {score}"
            )

            # Return (dict, backlog) tuple as expected by BaseEnv
            return {

@@ -117,6 +130,7 @@ class SkyRLServerEnv(BaseEnv):
        except Exception as e:
            logger.error(f"Inference error | Task {uid}: {e}")
            import traceback

            traceback.print_exc()
            # Return empty to allow the loop to continue
            return {

@@ -170,7 +184,7 @@ class SkyRLServerEnv(BaseEnv):
        """
        self.status_dict = {
            "current_step": 0,
            "queue_size": 0, # Asynchronous sampling - always ready for more
            "queue_size": 0,  # Asynchronous sampling - always ready for more
            "max_group_size": self.config.group_size,
            "self_queue_size": 0,
            "batches_offpolicy": 0,

@@ -178,6 +192,7 @@ class SkyRLServerEnv(BaseEnv):
        }
        return self.status_dict


if __name__ == "__main__":
    # Launch the SkyRLServerEnv via the BaseEnv CLI (serve or process)
    SkyRLServerEnv.cli()

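As the comment notes, BaseEnv.cli() exposes the serve and process subcommands. Assuming the module above is saved as skyrl_server_env.py (a hypothetical filename, not shown in this diff), launching would look roughly like:

python skyrl_server_env.py serve      # serve rollouts to a trainer
python skyrl_server_env.py process    # offline generation
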
@@ -1,19 +1,19 @@
import logging
import os
from typing import Any, Dict, List

import torch
import uvicorn
import logging
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from transformers import AutoTokenizer
from skyrl.backends.skyrl_train.workers.model_wrapper import HFModelWrapper
from typing import List, Dict, Any
from transformers import AutoTokenizer

app = FastAPI()

# Logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] SkyRL-Bridge: %(message)s'
    level=logging.INFO, format="%(asctime)s [%(levelname)s] SkyRL-Bridge: %(message)s"
)
logger = logging.getLogger(__name__)

@@ -21,6 +21,7 @@ logger = logging.getLogger(__name__)
model = None
tokenizer = None


@app.on_event("startup")
async def load_model():
    global model, tokenizer

@@ -33,17 +34,19 @@ async def load_model():

    model = HFModelWrapper(
        model_path,
        use_flash_attention_2=False, # Stable SDPA for RTX 3090/CUDA 13
        use_flash_attention_2=False,  # Stable SDPA for RTX 3090/CUDA 13
        bf16=True,
        device_map="cuda:0"
        device_map="cuda:0",
    )
    model.eval()
    logger.info("SkyRL-Native Bridge is ready.")


@app.get("/health")
async def health():
    return {"status": "ok"}


@app.post("/generate")
async def generate(request: Request):
    data = await request.json()

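The /health route above gives a cheap readiness probe for the bridge. A client-side check, assuming the server runs with the argparse defaults further down (host 0.0.0.0, port 9001):

import requests

# Expects {"status": "ok"} once load_model() has finished on startup.
resp = requests.get("http://localhost:9001/health", timeout=5)
print(resp.json())
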
@@ -62,7 +65,7 @@ async def generate(request: Request):
    max_new_tokens = data.get("max_tokens", 256)
    temperature = data.get("temperature", 1.0)
    top_p = data.get("top_p", 1.0)
    n = data.get("n", 1) # Number of completions
    n = data.get("n", 1)  # Number of completions

    responses = []
    # vLLM-style logprobs (first token of response)

@@ -80,10 +83,10 @@ async def generate(request: Request):
            return_dict_in_generate=True,
            output_scores=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id
            eos_token_id=tokenizer.eos_token_id,
        )

        gen_tokens = output.sequences[0][len(input_ids[0]):].tolist()
        gen_tokens = output.sequences[0][len(input_ids[0]) :].tolist()

        # Calculate logprobs for generated tokens
        # scores is a tuple of (max_new_tokens,) tensors of shape (batch, vocab_size)

@@ -96,22 +99,28 @@ async def generate(request: Request):
            # Format: [{token_id: logprob}] as expected by vllm_server.py:215
            logprobs_list.append([{str(token_id): token_logprob}])

        responses.append({
            "token_ids": gen_tokens,
            "logprobs": logprobs_list,
            "finish_reason": "stop" if gen_tokens[-1] == tokenizer.eos_token_id else "length"
        })
        responses.append(
            {
                "token_ids": gen_tokens,
                "logprobs": logprobs_list,
                "finish_reason": (
                    "stop" if gen_tokens[-1] == tokenizer.eos_token_id else "length"
                ),
            }
        )

    # Mimic vLLM response format
    # results["logprobs"] is a list of logprobs_list for each 'n' completion
    result = {
        "logprobs": [resp["logprobs"] for resp in responses],
        "finish_reasons": [resp["finish_reason"] for resp in responses]
        "finish_reasons": [resp["finish_reason"] for resp in responses],
    }
    return JSONResponse(content=result)

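For reference, each generated token contributes one single-entry dict keyed by the stringified token id, and the top-level response nests one such list per completion. A worked example of the shape (token ids and logprob values made up):

# One completion with two generated tokens:
logprobs_list = [{"1234": -0.05}, {"5678": -1.31}]

# result["logprobs"] holds one logprobs_list per completion (n=2 here),
# and result["finish_reasons"] lines up index-for-index with it.
result = {
    "logprobs": [
        [{"1234": -0.05}, {"5678": -1.31}],
        [{"1234": -0.07}, {"9012": -0.42}],
    ],
    "finish_reasons": ["stop", "length"],
}
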

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=9001)
    parser.add_argument("--host", type=str, default="0.0.0.0")

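A rough client-side call against the /generate endpoint, using only the request fields visible in this diff (max_tokens, temperature, top_p, n); the prompt field name is an assumption, since its extraction lies outside the shown hunks:

import requests

payload = {
    "prompt": "Hello",  # assumed field name; not shown in this diff
    "max_tokens": 64,
    "temperature": 0.7,
    "top_p": 1.0,
    "n": 1,
}
resp = requests.post("http://localhost:9001/generate", json=payload, timeout=120)
print(resp.json()["finish_reasons"])
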
@@ -132,12 +132,15 @@ from vllm.logger import init_logger  # noqa: E402
from vllm.sampling_params import RequestOutputKind, SamplingParams  # noqa: E402
from vllm.usage.usage_lib import UsageContext  # noqa: E402
from vllm.utils import random_uuid  # noqa: E402

# Handle vLLM engine version differences (v0 vs v1)
if os.environ.get("VLLM_USE_V1", "0") == "1":
    from vllm.v1.engine.async_llm import AsyncLLM  # noqa: E402
else:
    try:
        from vllm.engine.async_llm_engine import AsyncLLMEngine as AsyncLLM  # noqa: E402
        from vllm.engine.async_llm_engine import (
            AsyncLLMEngine as AsyncLLM,
        )  # noqa: E402
    except ImportError:
        # Fallback for older v0 versions
        from vllm.engine.async_llm import AsyncLLM  # noqa: E402

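Whichever branch above wins, downstream code sees a single AsyncLLM name. A rough construction sketch, assuming both the v0 AsyncLLMEngine and the v1 AsyncLLM expose from_engine_args (true for recent vLLM releases, but worth verifying against the pinned version):

from vllm.engine.arg_utils import AsyncEngineArgs

# Example model id; substitute whatever this server is actually configured to load.
engine_args = AsyncEngineArgs(model="Qwen/Qwen2.5-0.5B-Instruct")
engine = AsyncLLM.from_engine_args(engine_args)
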