diff --git a/environments/gsm8k_server.py b/environments/gsm8k_server.py index 6ae5285b..b38b6549 100644 --- a/environments/gsm8k_server.py +++ b/environments/gsm8k_server.py @@ -231,6 +231,12 @@ class GSM8kEnv(BaseEnv): gold_answer = ( "\\boxed{" + item["answer"].split("#")[-1].strip().replace(",", "") + "}" ) + question_preview = item["question"].replace("\n", " ")[:120] + print( + f"[GSM8K_DEBUG] collect_start group_size={self.config.group_size} " + f"q={question_preview!r}", + flush=True, + ) async with self.server.managed_server(tokenizer=self.tokenizer) as managed: @@ -243,10 +249,24 @@ class GSM8kEnv(BaseEnv): state = managed.get_state() nodes = state["nodes"] + print( + f"[GSM8K_DEBUG] completion_batch_received choices={len(chat_completions.choices)} " + f"nodes={len(nodes)}", + flush=True, + ) to_score = list() to_backlog = list() for i, chat_completion in enumerate(chat_completions.choices): + response_text = chat_completion.message.content or "" + response_preview = response_text.replace("\n", " ")[:220] + valid_mask_count = sum(1 for m in nodes[i].masked_tokens if m != -100) + print( + f"[GSM8K_DEBUG] response_received idx={i} finish={chat_completion.finish_reason} " + f"tokens={len(nodes[i].tokens)} valid_masked={valid_mask_count} " + f"text={response_preview!r}", + flush=True, + ) messages = ( {"role": "system", "content": system_prompt}, user_message, @@ -263,6 +283,11 @@ class GSM8kEnv(BaseEnv): } ) to_postprocess = await self.score(to_score) + accepted = 0 if to_postprocess is None else len(to_postprocess.get("tokens", [])) + print( + f"[GSM8K_DEBUG] collect_done accepted={accepted} submitted={len(to_score)}", + flush=True, + ) return to_postprocess, to_backlog async def score( @@ -278,10 +303,15 @@ class GSM8kEnv(BaseEnv): extraction_mode="first_match", extraction_config=[LatexExtractionConfig()], ) + print( + f"[GSM8K_DEBUG] score_start candidates={len(rollout_group_data)} " + f"gold_parsed_len={len(gold_parsed)}", + flush=True, + ) if len(gold_parsed) != 0: # We require the answer to be provided in correct latex (no malformed operators) random.shuffle(rollout_group_data) - for item in rollout_group_data: + for idx, item in enumerate(rollout_group_data): # print(item[0][-1]["content"]) answer_parsed = parse( item["messages"][-1]["content"].split("")[-1], @@ -310,7 +340,19 @@ class GSM8kEnv(BaseEnv): logprobs = item["logprobs"] # remove obviously bad examples - if len([1 for i in masks if i != -100]) < 10: + valid_mask_count = len([1 for i in masks if i != -100]) + print( + f"[GSM8K_DEBUG] score_candidate idx={idx} parsed_len={len(answer_parsed)} " + f"reward={bool(reward)} valid_masked={valid_mask_count} " + f"tokens={len(tokens)}", + flush=True, + ) + if valid_mask_count < 10: + print( + f"[GSM8K_DEBUG] drop_candidate idx={idx} reason=valid_masked_lt_10 " + f"value={valid_mask_count}", + flush=True, + ) continue scores["tokens"].append(tokens) scores["masks"].append(masks) @@ -323,6 +365,13 @@ class GSM8kEnv(BaseEnv): for score in scores["scores"]: self.percent_correct_buffer.append(max(score, 0)) + if len(scores["scores"]) == 0: + print( + "[GSM8K_DEBUG] drop_group reason=no_valid_candidates_after_filtering", + flush=True, + ) + return None + # check if all the same # print(scores['scores']) if all([score == 1 for score in scores["scores"]]): @@ -330,6 +379,10 @@ class GSM8kEnv(BaseEnv): token_lengths = [len(token) for token in scores["tokens"]] if max(token_lengths) == 0: # What? But don't want to crash a run so just in case... + print( + "[GSM8K_DEBUG] drop_group reason=zero_token_length_after_penalty_branch", + flush=True, + ) return None # Get max allowed token length from config @@ -353,10 +406,20 @@ class GSM8kEnv(BaseEnv): # Apply linear penalty scaling from 1.0 down to 0.0 scores["scores"].append(1.0 - percentage_of_range) if all([scores["scores"][0] == score for score in scores["scores"]]): + print( + f"[GSM8K_DEBUG] drop_group reason=all_scores_identical scores={scores['scores']}", + flush=True, + ) return None # If all the same, we return None + print( + f"[GSM8K_DEBUG] score_done accepted={len(scores['scores'])} " + f"scores={scores['scores']}", + flush=True, + ) return scores else: # If the gold solution is not parseable, we return None + print("[GSM8K_DEBUG] drop_group reason=gold_unparseable", flush=True) return None async def get_next_item(self) -> GSM8kRow: