diff --git a/atroposlib/api/server.py b/atroposlib/api/server.py
index ac134300..3a0fb999 100644
--- a/atroposlib/api/server.py
+++ b/atroposlib/api/server.py
@@ -4,7 +4,7 @@ import time
 import uuid
 from typing import Any, Dict, List, Optional
 
-from fastapi import FastAPI, Request, status
+from fastapi import FastAPI, status
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.gzip import GZipMiddleware
 from fastapi.responses import PlainTextResponse
@@ -351,7 +351,7 @@ async def info():
 
 
 @app.get("/batch")
-async def get_batch(request: Request):
+async def get_batch():
     # Check if trainer has registered first
     if not hasattr(app.state, "started"):
         return {
@@ -363,27 +363,8 @@ async def get_batch(request: Request):
     if not app.state.started:
         app.state.started = True
 
-    client = request.client
-    client_addr = (
-        f"{client.host}:{client.port}" if client is not None else "unknown-client"
-    )
-    client_tag = request.headers.get("x-atropos-client", "unknown")
-    client_pid = request.headers.get("x-atropos-pid", "unknown")
-
     if len(app.state.curr_batch) > 0:
-        curr_batch = app.state.curr_batch.pop()
-        logger.warning(
-            "API /batch returning prebuilt batch to client=%s pid=%s addr=%s: "
-            "groups=%s sequences=%s curr_batch_remaining=%s queue_groups=%s",
-            client_tag,
-            client_pid,
-            client_addr,
-            len(curr_batch),
-            sum(len(x["tokens"]) for x in curr_batch),
-            len(app.state.curr_batch),
-            len(app.state.queue),
-        )
-        return {"batch": curr_batch}
+        return {"batch": app.state.curr_batch.pop()}
     else:
         new_batches = []
         # Check if any envs have minimum allocations
@@ -413,21 +394,6 @@ async def get_batch(request: Request):
         )
         steps_to_take = len(new_batches)
         if steps_to_take == 0:
-            now = time.time()
-            last_empty_log = getattr(app.state, "_last_empty_batch_log", 0.0)
-            if now - last_empty_log > 30:
-                logger.warning(
-                    "API /batch no full batch ready for client=%s pid=%s addr=%s: "
-                    "queue_groups=%s queue_sequences=%s curr_batch=%s batch_size=%s",
-                    client_tag,
-                    client_pid,
-                    client_addr,
-                    len(app.state.queue),
-                    sum(len(x.get("tokens", [])) for x in app.state.queue),
-                    len(app.state.curr_batch),
-                    getattr(app.state, "batchsize", -1),
-                )
-                app.state._last_empty_batch_log = now
             return {"batch": None}
         app.state.status_dict["step"] += steps_to_take
         # chunk it
@@ -435,18 +401,9 @@ async def get_batch(request: Request):
             app.state.curr_batch.append(batch)
         curr_batch = app.state.curr_batch.pop()
         # check length before sending
-        logger.warning(
-            "API /batch built %s trainer batch(es); returning one to client=%s pid=%s addr=%s "
-            "with %s groups / %s sequences; curr_batch_remaining=%s queue_groups_remaining=%s new_current_step=%s",
-            steps_to_take,
-            client_tag,
-            client_pid,
-            client_addr,
-            len(curr_batch),
+        logger.info(
+            "Sending batch of %s sequences",
             sum(len(x["tokens"]) for x in curr_batch),
-            len(app.state.curr_batch),
-            len(app.state.queue),
-            app.state.status_dict["step"],
         )
         return {"batch": curr_batch}
 
diff --git a/atroposlib/envs/base.py b/atroposlib/envs/base.py
index 7aa391ba..3d3b6c20 100644
--- a/atroposlib/envs/base.py
+++ b/atroposlib/envs/base.py
@@ -907,7 +907,7 @@ class BaseEnv(ABC):
                     "ensure your trainer handles this appropriately."
                 )
             elif abort_on_any_max_length_exceeded and any(
-                [len(x) > self.max_token_len for x in group["tokens"]]
+                [len(x) >= self.max_token_len for x in group["tokens"]]
             ):
                 logger.warning("Token length is too long in a group, skipping...")
                 continue
diff --git a/atroposlib/envs/server_handling/managed_server.py b/atroposlib/envs/server_handling/managed_server.py
index a8e97077..9d46f265 100644
--- a/atroposlib/envs/server_handling/managed_server.py
+++ b/atroposlib/envs/server_handling/managed_server.py
@@ -447,33 +447,14 @@ class ManagedServer:
         if not self.track_tree and self.tokenizer is not None:
             input_ids = self._compute_input_ids(prompt, extending_node)
             completion_kwargs["input_ids"] = input_ids
-            logger.warning(
-                "managed_server chat_completion prepared input_ids=%s extending=%s",
-                len(input_ids),
-                extending_node is not None,
-            )
-        else:
-            logger.warning(
-                "managed_server chat_completion using prompt passthrough track_tree=%s tokenizer=%s",
-                self.track_tree,
-                self.tokenizer is not None,
-            )
 
         # Call the tokens and logprobs wrapper directly
-        logger.warning(
-            "managed_server chat_completion calling backend completion wrapper"
-        )
         (
             prompt_tokens,
             output_tokens_list,
             output_logprobs_list,
             finish_reasons,
         ) = await self.server.tokens_and_logprobs_completion(**completion_kwargs)
-        logger.warning(
-            "managed_server chat_completion backend returned prompt_tokens=%s outputs=%s",
-            len(prompt_tokens),
-            len(output_tokens_list),
-        )
 
         # Track each completion and build choices
         n = len(output_tokens_list)
diff --git a/atroposlib/envs/server_handling/server_manager.py b/atroposlib/envs/server_handling/server_manager.py
index b24698a6..d34f69c9 100644
--- a/atroposlib/envs/server_handling/server_manager.py
+++ b/atroposlib/envs/server_handling/server_manager.py
@@ -106,13 +106,6 @@ class ServerManager:
             self.servers = [ServerHarness()]
             return
         if not isinstance(configs, list):
-            logger.warning(
-                "ServerManager: configs is NOT a list (type=%s). "
-                "Using auto-generated URLs (template mode). "
-                "Passed base_url=%s will be IGNORED.",
-                type(configs).__name__,
-                getattr(configs, "base_url", "N/A"),
-            )
             urls = []
             if os.environ.get("SLURM_JOB_NODELIST", None) is not None:
                 nodelist = (
@@ -155,21 +148,11 @@ class ServerManager:
                 server_class(config, reasoning_config=reasoning_config)
                 for config in openai_configs
             ]
-            logger.warning(
-                "ServerManager: auto-generated %s server(s) at URLs: %s",
-                len(self.servers),
-                [c.base_url for c in openai_configs],
-            )
         elif not slurm:
             self.servers = [
                 server_class(config, reasoning_config=reasoning_config)
                 for config in configs
             ]
-            logger.warning(
-                "ServerManager: using %s explicit config(s) at URLs: %s",
-                len(self.servers),
-                [c.base_url for c in configs],
-            )
         else:
             nodelist = (
                 os.popen(f'scontrol show hostnames {os.environ["SLURM_JOB_NODELIST"]}')
diff --git a/atroposlib/envs/server_handling/vllm_server.py b/atroposlib/envs/server_handling/vllm_server.py
index 18b8333e..acc26830 100644
--- a/atroposlib/envs/server_handling/vllm_server.py
+++ b/atroposlib/envs/server_handling/vllm_server.py
@@ -193,14 +193,6 @@ class VLLMServer(APIServer):
         # Prepare request for VLLM native API
         request_data = {"prompt": {"prompt_token_ids": prompt_tokens}, "logprobs": 0}
         request_data.update(kwargs)
-        logger.warning(
-            "vllm_server completion POST start base_url=%s prompt_tokens=%s n=%s max_tokens=%s temperature=%s",
-            self.config.base_url,
-            len(prompt_tokens),
-            request_data.get("n"),
-            request_data.get("max_tokens"),
-            request_data.get("temperature"),
-        )
 
         # Make async request to VLLM /generate endpoint
         async with aiohttp.ClientSession() as session:
@@ -216,11 +208,6 @@ class VLLMServer(APIServer):
             ) as response:
                 response.raise_for_status()
                 results = await response.json()
-                logger.warning(
-                    "vllm_server completion POST done outputs=%s finish_reasons=%s",
-                    len(results.get("logprobs", [])),
-                    len(results.get("finish_reasons", [])),
-                )
                 output_tokens_list = []
                 output_logprobs_list = []
                 finish_reasons_list = []
@@ -330,13 +317,6 @@ class VLLMServer(APIServer):
         request_data["temperature"] = 0.0
         request_data["top_p"] = 1.0
         request_data.setdefault("max_tokens", 1)
-        logger.warning(
-            "vllm_server get_logprobs POST start base_url=%s prompt_tokens=%s top_k=%s max_tokens=%s",
-            self.config.base_url,
-            len(prompt_tokens),
-            top_k,
-            request_data.get("max_tokens"),
-        )
 
         async with aiohttp.ClientSession() as session:
             async with session.post(
@@ -351,10 +331,6 @@ class VLLMServer(APIServer):
             ) as response:
                 response.raise_for_status()
                 results = await response.json()
-                logger.warning(
-                    "vllm_server get_logprobs POST done prompt_logprobs_present=%s",
-                    results.get("prompt_logprobs") is not None,
-                )
 
         raw_prompt_logprobs = results.get("prompt_logprobs")
         if raw_prompt_logprobs is None:
@@ -451,10 +427,6 @@ def resolve_openai_configs(
     elif isinstance(default_server_configs, list):
         server_configs = [final_openai_config]
     else:
-        logger.warning(
-            f"Unexpected type for default_server_configs: {type(default_server_configs)}. "
-            f"Proceeding with single OpenAI server configuration based on merged settings."
-        )
         server_configs = [final_openai_config]
 
     return server_configs
diff --git a/example_trainer/api.py b/example_trainer/api.py
index fe0ac38a..dc51af4f 100644
--- a/example_trainer/api.py
+++ b/example_trainer/api.py
@@ -100,34 +100,15 @@ def get_batch(url: str = "http://localhost:8000"):
     Raises:
         RuntimeError: If trainer is not registered or other API error
     """
-    try:
-        response = requests.get(
-            f"{url}/batch",
-            headers={
-                "X-Atropos-Client": "trainer",
-                "X-Atropos-Pid": str(os.getpid()),
-            },
-            timeout=10,
-        )
-        print(
-            f" [Trainer/API] GET /batch status={response.status_code}",
-            flush=True,
-        )
-        data = response.json()
-        batch = data.get("batch")
-        if batch is None:
-            print(" [Trainer/API] parsed batch=None", flush=True)
-        else:
-            num_groups = len(batch)
-            num_sequences = sum(len(item["tokens"]) for item in batch)
-            print(
-                " [Trainer/API] parsed batch payload: "
-                f"groups={num_groups} sequences={num_sequences}",
-                flush=True,
-            )
-    except Exception as exc:
-        print(f" [Trainer/API] GET /batch failed: {exc!r}", flush=True)
-        raise
+    response = requests.get(
+        f"{url}/batch",
+        headers={
+            "X-Atropos-Client": "trainer",
+            "X-Atropos-Pid": str(os.getpid()),
+        },
+        timeout=10,
+    )
+    data = response.json()
 
     # Check if there was an error (trainer not registered)
     if data.get("status") == "error":
diff --git a/example_trainer/data.py b/example_trainer/data.py
index 4823eb64..0aa1a88a 100644
--- a/example_trainer/data.py
+++ b/example_trainer/data.py
@@ -377,61 +377,12 @@ def get_data(
     - inference_logprob_batches are aligned with labels for proper GRPO loss computation
     """
     batches = []
-    _logged_logprob_warning = False
-    empty_polls = 0
 
     while True:
         data = get_batch(url=atropos_url)
 
         if data["batch"] is not None:
-            empty_polls = 0
-            num_groups = len(data["batch"])
-            num_sequences = sum(len(item["tokens"]) for item in data["batch"])
-            max_seq_len = max(
-                max(len(seq) for seq in item["tokens"]) for item in data["batch"]
-            )
-            print(
-                " [Data] received API batch: "
-                f"groups={num_groups} sequences={num_sequences} max_seq_len={max_seq_len}",
-                flush=True,
-            )
-            # DEBUG: Check if inference_logprobs exists in the data
-            if not _logged_logprob_warning:
-                has_logprobs = any(
-                    "inference_logprobs" in item for item in data["batch"]
-                )
-                if has_logprobs:
-                    # Check if they're non-empty
-                    sample_item = next(
-                        (
-                            item
-                            for item in data["batch"]
-                            if "inference_logprobs" in item
-                        ),
-                        None,
-                    )
-                    if sample_item and sample_item.get("inference_logprobs"):
-                        sample_lp = (
-                            sample_item["inference_logprobs"][0]
-                            if sample_item["inference_logprobs"]
-                            else []
-                        )
-                        print(
-                            f" [Data] ✓ inference_logprobs found in batch (sample len: {len(sample_lp)})"
-                        )
-                    else:
-                        print(
-                            " [Data] ⚠ inference_logprobs key exists but is empty!"
-                        )
-                else:
-                    print(" [Data] ⚠ NO inference_logprobs in batch data!")
-                    print(
-                        f" [Data] Keys in first item: {list(data['batch'][0].keys())}"
-                    )
-                _logged_logprob_warning = True
-
             # Process and accumulate batches (now includes batched inference logprobs)
-            print(" [Data] padding / batching API payload...", flush=True)
             (
                 token_batches,
                 label_batches,
@@ -441,12 +392,6 @@ def get_data(
                 distill_token_id_batches,
                 distill_logprob_batches,
             ) = pad_data_to_good_offset(data, batch_size, extract_inference_logprobs)
-            batch_shapes = [tuple(tb.shape) for tb in token_batches]
-            print(
-                " [Data] pad_data_to_good_offset done: "
-                f"micro_batches={len(token_batches)} token_batch_shapes={batch_shapes}",
-                flush=True,
-            )
 
             # Include inference logprob batches in the tuple
             batches.append(
@@ -463,17 +408,7 @@ def get_data(
 
         elif len(batches) > 0:
             # Return accumulated batches when no more data
-            print(
-                f" [Data] returning {len(batches)} assembled trainer batch tuple(s)",
-                flush=True,
-            )
             return batches, None
         else:
             # Wait for data
-            empty_polls += 1
-            if empty_polls == 1 or empty_polls % 30 == 0:
-                print(
-                    f" [Data] no batch ready yet (polls_without_data={empty_polls})",
-                    flush=True,
-                )
             time.sleep(1)