change OPD style

This commit is contained in:
Jai Suphavadeeprasit 2026-02-19 19:19:23 -05:00
parent 527433b5bc
commit 0dcc9156d2

View file

@ -231,6 +231,12 @@ class GSM8kEnv(BaseEnv):
gold_answer = (
"\\boxed{" + item["answer"].split("#")[-1].strip().replace(",", "") + "}"
)
question_preview = item["question"].replace("\n", " ")[:120]
print(
f"[GSM8K_DEBUG] collect_start group_size={self.config.group_size} "
f"q={question_preview!r}",
flush=True,
)
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
@ -243,10 +249,24 @@ class GSM8kEnv(BaseEnv):
state = managed.get_state()
nodes = state["nodes"]
print(
f"[GSM8K_DEBUG] completion_batch_received choices={len(chat_completions.choices)} "
f"nodes={len(nodes)}",
flush=True,
)
to_score = list()
to_backlog = list()
for i, chat_completion in enumerate(chat_completions.choices):
response_text = chat_completion.message.content or ""
response_preview = response_text.replace("\n", " ")[:220]
valid_mask_count = sum(1 for m in nodes[i].masked_tokens if m != -100)
print(
f"[GSM8K_DEBUG] response_received idx={i} finish={chat_completion.finish_reason} "
f"tokens={len(nodes[i].tokens)} valid_masked={valid_mask_count} "
f"text={response_preview!r}",
flush=True,
)
messages = (
{"role": "system", "content": system_prompt},
user_message,
@ -263,6 +283,11 @@ class GSM8kEnv(BaseEnv):
}
)
to_postprocess = await self.score(to_score)
accepted = 0 if to_postprocess is None else len(to_postprocess.get("tokens", []))
print(
f"[GSM8K_DEBUG] collect_done accepted={accepted} submitted={len(to_score)}",
flush=True,
)
return to_postprocess, to_backlog
async def score(
@ -278,10 +303,15 @@ class GSM8kEnv(BaseEnv):
extraction_mode="first_match",
extraction_config=[LatexExtractionConfig()],
)
print(
f"[GSM8K_DEBUG] score_start candidates={len(rollout_group_data)} "
f"gold_parsed_len={len(gold_parsed)}",
flush=True,
)
if len(gold_parsed) != 0:
# We require the answer to be provided in correct latex (no malformed operators)
random.shuffle(rollout_group_data)
for item in rollout_group_data:
for idx, item in enumerate(rollout_group_data):
# print(item[0][-1]["content"])
answer_parsed = parse(
item["messages"][-1]["content"].split("</think>")[-1],
@ -310,7 +340,19 @@ class GSM8kEnv(BaseEnv):
logprobs = item["logprobs"]
# remove obviously bad examples
if len([1 for i in masks if i != -100]) < 10:
valid_mask_count = len([1 for i in masks if i != -100])
print(
f"[GSM8K_DEBUG] score_candidate idx={idx} parsed_len={len(answer_parsed)} "
f"reward={bool(reward)} valid_masked={valid_mask_count} "
f"tokens={len(tokens)}",
flush=True,
)
if valid_mask_count < 10:
print(
f"[GSM8K_DEBUG] drop_candidate idx={idx} reason=valid_masked_lt_10 "
f"value={valid_mask_count}",
flush=True,
)
continue
scores["tokens"].append(tokens)
scores["masks"].append(masks)
@ -323,6 +365,13 @@ class GSM8kEnv(BaseEnv):
for score in scores["scores"]:
self.percent_correct_buffer.append(max(score, 0))
if len(scores["scores"]) == 0:
print(
"[GSM8K_DEBUG] drop_group reason=no_valid_candidates_after_filtering",
flush=True,
)
return None
# check if all the same
# print(scores['scores'])
if all([score == 1 for score in scores["scores"]]):
@ -330,6 +379,10 @@ class GSM8kEnv(BaseEnv):
token_lengths = [len(token) for token in scores["tokens"]]
if max(token_lengths) == 0:
# What? But don't want to crash a run so just in case...
print(
"[GSM8K_DEBUG] drop_group reason=zero_token_length_after_penalty_branch",
flush=True,
)
return None
# Get max allowed token length from config
@ -353,10 +406,20 @@ class GSM8kEnv(BaseEnv):
# Apply linear penalty scaling from 1.0 down to 0.0
scores["scores"].append(1.0 - percentage_of_range)
if all([scores["scores"][0] == score for score in scores["scores"]]):
print(
f"[GSM8K_DEBUG] drop_group reason=all_scores_identical scores={scores['scores']}",
flush=True,
)
return None # If all the same, we return None
print(
f"[GSM8K_DEBUG] score_done accepted={len(scores['scores'])} "
f"scores={scores['scores']}",
flush=True,
)
return scores
else:
# If the gold solution is not parseable, we return None
print("[GSM8K_DEBUG] drop_group reason=gold_unparseable", flush=True)
return None
async def get_next_item(self) -> GSM8kRow: