diff --git a/atroposlib/envs/server_handling/server_manager.py b/atroposlib/envs/server_handling/server_manager.py
index 76851dea..41ed3bf5 100644
--- a/atroposlib/envs/server_handling/server_manager.py
+++ b/atroposlib/envs/server_handling/server_manager.py
@@ -15,8 +15,8 @@ from atroposlib.envs.server_handling.server_baseline import (
     ServerBaseline,
 )
 from atroposlib.envs.server_handling.server_harness import ServerHarness
-from atroposlib.envs.server_handling.trl_vllm_server import TrlVllmServer
 from atroposlib.envs.server_handling.sglang_server import SGLangServer
+from atroposlib.envs.server_handling.trl_vllm_server import TrlVllmServer
 
 
 class ServerManagerConfig(BaseModel):
@@ -289,7 +289,9 @@ class ServerManager:
                 most_available_server_num_slots = (
                     server.sem._value if is_train else server.eval_sem._value
                 )
-        return await self.servers[most_available_server].tokens_and_logprobs_completion(**kwargs)
+        return await self.servers[most_available_server].tokens_and_logprobs_completion(
+            **kwargs
+        )
 
     @asynccontextmanager
     async def dedicated_server(self) -> AsyncGenerator[OpenAIServer, None]:
diff --git a/atroposlib/envs/server_handling/sglang_server.py b/atroposlib/envs/server_handling/sglang_server.py
index e0e44327..1a2f1362 100644
--- a/atroposlib/envs/server_handling/sglang_server.py
+++ b/atroposlib/envs/server_handling/sglang_server.py
@@ -6,10 +6,10 @@ import openai
 from openai.types.chat.chat_completion import ChatCompletion
 from openai.types.completion import Completion
 from pydantic_cli import FailedExecutionException
+from transformers import AutoTokenizer
 
 from atroposlib.envs.constants import NAMESPACE_SEP, OPENAI_NAMESPACE
 from atroposlib.envs.server_handling.server_baseline import APIServer, APIServerConfig
-from transformers import AutoTokenizer
 
 
 class SGLangServer(APIServer):
@@ -171,12 +171,17 @@ class SGLangServer(APIServer):
         # Make async request to SGLang /generate endpoint
         import aiohttp
+
         async with aiohttp.ClientSession() as session:
             async with session.post(
                 f"{self.config.base_url.replace('/v1', '')}/generate",
                 json=request_data,
-                headers={"Authorization": f"Bearer {self.config.api_key}"} if self.config.api_key else {},
-                timeout=aiohttp.ClientTimeout(total=self.config.timeout)
+                headers=(
+                    {"Authorization": f"Bearer {self.config.api_key}"}
+                    if self.config.api_key
+                    else {}
+                ),
+                timeout=aiohttp.ClientTimeout(total=self.config.timeout),
             ) as response:
                 response.raise_for_status()
                 results = await response.json()
 
@@ -194,8 +199,12 @@ class SGLangServer(APIServer):
             # Get output logprobs - extract just the logprob values
             output_token_logprobs = meta_info.get("output_token_logprobs", [])
-            logprobs = [item[0] for item in output_token_logprobs]  # Extract logprob from (logprob, token_id, text) tuples
-            output_ids = [item[1] for item in output_token_logprobs]  # Extract token ID from (logprob, token_id, text) tuples
+            logprobs = [
+                item[0] for item in output_token_logprobs
+            ]  # Extract logprob from (logprob, token_id, text) tuples
+            output_ids = [
+                item[1] for item in output_token_logprobs
+            ]  # Extract token ID from (logprob, token_id, text) tuples
 
             # Get finish reason
             finish_reason = meta_info.get("finish_reason", None)
 
@@ -204,7 +213,12 @@ class SGLangServer(APIServer):
             output_logprobs_list.append(logprobs)
             finish_reasons_list.append(finish_reason)
 
-        return prompt_tokens, output_tokens_list, output_logprobs_list, finish_reasons_list
+        return (
+            prompt_tokens,
+            output_tokens_list,
+            output_logprobs_list,
+            finish_reasons_list,
+        )
 
 
 def resolve_openai_configs(
@@ -301,13 +315,15 @@ if __name__ == "__main__":
         # Test the tokens_and_logprobs_completion method
         print("Testing tokens_and_logprobs_completion...")
         try:
-            prompt_tokens, output_tokens, output_logprobs, finish_reasons = await server.tokens_and_logprobs_completion(
-                prompt="The capital of France is",
-                n=4,
-                max_tokens=32,
-                temperature=1.0,
-                top_p=1.0,
-                stop=["User:", "Human:", "Assistant:", ""],
+            prompt_tokens, output_tokens, output_logprobs, finish_reasons = (
+                await server.tokens_and_logprobs_completion(
+                    prompt="The capital of France is",
+                    n=4,
+                    max_tokens=32,
+                    temperature=1.0,
+                    top_p=1.0,
+                    stop=["User:", "Human:", "Assistant:", ""],
+                )
             )
 
             print(f"\nResults:")
@@ -317,12 +333,15 @@ if __name__ == "__main__":
             print(f"Finish reasons: {finish_reasons}")
             print(f"\nNumber of completions: {len(output_tokens)}")
             print(f"Output length: {[len(tokens) for tokens in output_tokens]}")
-            responses = '\n\n'.join([server.tokenizer.decode(tokens) for tokens in output_tokens])
+            responses = "\n\n".join(
+                [server.tokenizer.decode(tokens) for tokens in output_tokens]
+            )
             print(f"Responses:\n-{responses}")
 
         except Exception as e:
             print(f"Error: {e}")
             import traceback
+
             traceback.print_exc()
 
     # Run the test
diff --git a/environments/math_server_zero.py b/environments/math_server_zero.py
index 38991410..7596e6cf 100644
--- a/environments/math_server_zero.py
+++ b/environments/math_server_zero.py
@@ -18,11 +18,11 @@ from pydantic import Field
 from tqdm.asyncio import tqdm_asyncio
 
 from atroposlib.envs.base import (
-    ServerBaseline,
     BaseEnv,
     BaseEnvConfig,
     EvalHandlingEnum,
     ScoredDataGroup,
+    ServerBaseline,
 )
 
 prompt_format = (
@@ -49,7 +49,8 @@ class RSConfig(BaseEnvConfig):
         0.0, description="The percentage of items to have length penalty"
     )
     start_tok_length: int = Field(
-        8192, description="The starting length of the token length, scaled linearly to the max_token_length"
+        8192,
+        description="The starting length of the token length, scaled linearly to the max_token_length",
     )
 
 
@@ -149,12 +150,12 @@ class MathEnv(BaseEnv):
             wandb_name="math",
             eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
             eval_limit_ratio=0.1,
-            max_num_workers_per_node=24
+            max_num_workers_per_node=24,
         )
         server_configs = ServerBaseline(
             model_name="Qwen/Qwen2.5-7B",
             num_requests_for_eval=256,  # since evaling only on one...
-            server_type="sglang"
+            server_type="sglang",
         )
 
         return env_config, server_configs
@@ -240,10 +241,7 @@ class MathEnv(BaseEnv):
                         name,
                     )
                 )
-        for name, t_dataset in zip(
-            ['olympiad'],
-            [olympiad_test_data]
-        ):
+        for name, t_dataset in zip(["olympiad"], [olympiad_test_data]):
             for item in t_dataset:
                 self.test.append(
                     (
@@ -317,18 +315,22 @@ class MathEnv(BaseEnv):
         curr_length = int(curr_length * (self.curr_step / self.config.total_steps))
         curr_length += self.config.start_tok_length
         thinking_len = min(thinking_len, curr_length)
-        prompt_tokens, out_tokens, out_logprobs, finish_reasons = await self.server.tokens_and_logprobs_completion(
-            prompt=user_prompt,
-            n=self.config.group_size,
-            max_tokens=thinking_len,
-            temperature=1.0,
-            top_p=1.0,
-            stop=stop_list,
+        prompt_tokens, out_tokens, out_logprobs, finish_reasons = (
+            await self.server.tokens_and_logprobs_completion(
+                prompt=user_prompt,
+                n=self.config.group_size,
+                max_tokens=thinking_len,
+                temperature=1.0,
+                top_p=1.0,
+                stop=stop_list,
+            )
         )
         # print(completions, flush=True)
         to_score = list()
         to_backlog = list()
-        for i, (tokens, logprobs, finish_reason) in enumerate(zip(out_tokens, out_logprobs, finish_reasons)):
+        for i, (tokens, logprobs, finish_reason) in enumerate(
+            zip(out_tokens, out_logprobs, finish_reasons)
+        ):
             message = self.tokenizer.decode(prompt_tokens + tokens)
             to_score.append(
                 (
@@ -378,7 +380,7 @@ class MathEnv(BaseEnv):
             user_prompt_tokens = item[3]
             out_toks = item[4]
             out_logps = item[5]
-            if item[2]['type'] == "length":
+            if item[2]["type"] == "length":
                 reward = False
                 if self.config.mask_too_long_completions:
                     scores["overrides"][-1]["set_advantage_to_zero"] = True
@@ -392,7 +394,9 @@ class MathEnv(BaseEnv):
             masks = masks + out_toks
             inf_logp = [0 for _ in range(len(user_prompt_tokens))]
             inf_logp = inf_logp + out_logps
-            assert len(inf_logp) == len(masks), f"{len(inf_logp)}, {len(masks)} mismatch"
+            assert len(inf_logp) == len(
+                masks
+            ), f"{len(inf_logp)}, {len(masks)} mismatch"
             user_prompt = resp.split("")[0]
             messages = [
                 {"role": "user", "content": user_prompt},