diff --git a/atroposlib/api/server.py b/atroposlib/api/server.py index 3a0fb999..0ca0f02d 100644 --- a/atroposlib/api/server.py +++ b/atroposlib/api/server.py @@ -364,7 +364,15 @@ async def get_batch(): app.state.started = True if len(app.state.curr_batch) > 0: - return {"batch": app.state.curr_batch.pop()} + curr_batch = app.state.curr_batch.pop() + logger.warning( + "API /batch returning prebuilt batch: groups=%s sequences=%s curr_batch_remaining=%s queue_groups=%s", + len(curr_batch), + sum(len(x["tokens"]) for x in curr_batch), + len(app.state.curr_batch), + len(app.state.queue), + ) + return {"batch": curr_batch} else: new_batches = [] # Check if any envs have minimum allocations @@ -394,6 +402,17 @@ async def get_batch(): ) steps_to_take = len(new_batches) if steps_to_take == 0: + now = time.time() + last_empty_log = getattr(app.state, "_last_empty_batch_log", 0.0) + if now - last_empty_log > 30: + logger.warning( + "API /batch no full batch ready: queue_groups=%s queue_sequences=%s curr_batch=%s batch_size=%s", + len(app.state.queue), + sum(len(x.get("tokens", [])) for x in app.state.queue), + len(app.state.curr_batch), + getattr(app.state, "batchsize", -1), + ) + app.state._last_empty_batch_log = now return {"batch": None} app.state.status_dict["step"] += steps_to_take # chunk it @@ -401,9 +420,14 @@ async def get_batch(): app.state.curr_batch.append(batch) curr_batch = app.state.curr_batch.pop() # check length before sending - logger.info( - "Sending batch of %s sequences", + logger.warning( + "API /batch built %s trainer batch(es); returning one with %s groups / %s sequences; curr_batch_remaining=%s queue_groups_remaining=%s new_current_step=%s", + steps_to_take, + len(curr_batch), sum(len(x["tokens"]) for x in curr_batch), + len(app.state.curr_batch), + len(app.state.queue), + app.state.status_dict["step"], ) return {"batch": curr_batch} diff --git a/example_trainer/data.py b/example_trainer/data.py index 770d68fa..bf7f5b19 100644 --- a/example_trainer/data.py +++ b/example_trainer/data.py @@ -366,11 +366,23 @@ def get_data( """ batches = [] _logged_logprob_warning = False + empty_polls = 0 while True: data = get_batch(url=atropos_url) if data["batch"] is not None: + empty_polls = 0 + num_groups = len(data["batch"]) + num_sequences = sum(len(item["tokens"]) for item in data["batch"]) + max_seq_len = max( + max(len(seq) for seq in item["tokens"]) for item in data["batch"] + ) + print( + " [Data] received API batch: " + f"groups={num_groups} sequences={num_sequences} max_seq_len={max_seq_len}", + flush=True, + ) # DEBUG: Check if inference_logprobs exists in the data if not _logged_logprob_warning: has_logprobs = any( @@ -407,6 +419,7 @@ def get_data( _logged_logprob_warning = True # Process and accumulate batches (now includes batched inference logprobs) + print(" [Data] padding / batching API payload...", flush=True) ( token_batches, label_batches, @@ -416,6 +429,12 @@ def get_data( distill_token_id_batches, distill_logprob_batches, ) = pad_data_to_good_offset(data, batch_size, extract_inference_logprobs) + batch_shapes = [tuple(tb.shape) for tb in token_batches] + print( + " [Data] pad_data_to_good_offset done: " + f"micro_batches={len(token_batches)} token_batch_shapes={batch_shapes}", + flush=True, + ) # Include inference logprob batches in the tuple batches.append( @@ -432,7 +451,17 @@ def get_data( elif len(batches) > 0: # Return accumulated batches when no more data + print( + f" [Data] returning {len(batches)} assembled trainer batch tuple(s)", + flush=True, + ) return batches, None else: # Wait for data + empty_polls += 1 + if empty_polls == 1 or empty_polls % 30 == 0: + print( + f" [Data] no batch ready yet (polls_without_data={empty_polls})", + flush=True, + ) time.sleep(1) diff --git a/example_trainer/trainers.py b/example_trainer/trainers.py index cc96cee5..bff1763f 100644 --- a/example_trainer/trainers.py +++ b/example_trainer/trainers.py @@ -317,12 +317,17 @@ def train_shared_vllm(config: TrainingConfig): # Fetch data (with inference logprobs for proper GRPO loss) data_fetch_start = time.time() if len(batches) == 0: + print(" [Trainer] requesting data from Atropos API...", flush=True) batches, _ = get_data( config.batch_size, config.seq_len, config.atropos_url, extract_inference_logprobs=True, # Enable proper GRPO with reference logprobs ) + print( + f" [Trainer] get_data returned {len(batches)} trainer batch tuple(s)", + flush=True, + ) batch_data = batches.pop(0) token_batches, label_batches, advantage_batches, temperature_batches = ( batch_data[:4] @@ -330,6 +335,12 @@ def train_shared_vllm(config: TrainingConfig): inference_logprob_batches = batch_data[4] if len(batch_data) > 4 else None distill_token_id_batches = batch_data[5] if len(batch_data) > 5 else None distill_logprob_batches = batch_data[6] if len(batch_data) > 6 else None + token_shapes = [tuple(tb.shape) for tb in token_batches] + print( + " [Trainer] selected trainer batch: " + f"micro_batches={len(token_batches)} token_batch_shapes={token_shapes}", + flush=True, + ) data_fetch_time = time.time() - data_fetch_start benchmark_stats["data_fetch_times"].append(data_fetch_time)