mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-23 16:54:56 +00:00
change OPD style
This commit is contained in:
parent
527433b5bc
commit
0dcc9156d2
1 changed file with 65 additions and 2 deletions
|
|
@@ -231,6 +231,12 @@ class GSM8kEnv(BaseEnv):
|
|||
gold_answer = (
|
||||
"\\boxed{" + item["answer"].split("#")[-1].strip().replace(",", "") + "}"
|
||||
)
|
||||
question_preview = item["question"].replace("\n", " ")[:120]
|
||||
print(
|
||||
f"[GSM8K_DEBUG] collect_start group_size={self.config.group_size} "
|
||||
f"q={question_preview!r}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
|
||||
|
||||
|
|
@@ -243,10 +249,24 @@ class GSM8kEnv(BaseEnv):
|
|||
|
||||
state = managed.get_state()
|
||||
nodes = state["nodes"]
|
||||
print(
|
||||
f"[GSM8K_DEBUG] completion_batch_received choices={len(chat_completions.choices)} "
|
||||
f"nodes={len(nodes)}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
to_score = list()
|
||||
to_backlog = list()
|
||||
for i, chat_completion in enumerate(chat_completions.choices):
|
||||
response_text = chat_completion.message.content or ""
|
||||
response_preview = response_text.replace("\n", " ")[:220]
|
||||
valid_mask_count = sum(1 for m in nodes[i].masked_tokens if m != -100)
|
||||
print(
|
||||
f"[GSM8K_DEBUG] response_received idx={i} finish={chat_completion.finish_reason} "
|
||||
f"tokens={len(nodes[i].tokens)} valid_masked={valid_mask_count} "
|
||||
f"text={response_preview!r}",
|
||||
flush=True,
|
||||
)
|
||||
messages = (
|
||||
{"role": "system", "content": system_prompt},
|
||||
user_message,
|
||||
|
|
@@ -263,6 +283,11 @@ class GSM8kEnv(BaseEnv):
|
|||
}
|
||||
)
|
||||
to_postprocess = await self.score(to_score)
|
||||
accepted = 0 if to_postprocess is None else len(to_postprocess.get("tokens", []))
|
||||
print(
|
||||
f"[GSM8K_DEBUG] collect_done accepted={accepted} submitted={len(to_score)}",
|
||||
flush=True,
|
||||
)
|
||||
return to_postprocess, to_backlog
|
||||
|
||||
async def score(
|
||||
|
|
@@ -278,10 +303,15 @@ class GSM8kEnv(BaseEnv):
|
|||
extraction_mode="first_match",
|
||||
extraction_config=[LatexExtractionConfig()],
|
||||
)
|
||||
print(
|
||||
f"[GSM8K_DEBUG] score_start candidates={len(rollout_group_data)} "
|
||||
f"gold_parsed_len={len(gold_parsed)}",
|
||||
flush=True,
|
||||
)
|
||||
if len(gold_parsed) != 0:
|
||||
# We require the answer to be provided in correct latex (no malformed operators)
|
||||
random.shuffle(rollout_group_data)
|
||||
for item in rollout_group_data:
|
||||
for idx, item in enumerate(rollout_group_data):
|
||||
# print(item[0][-1]["content"])
|
||||
answer_parsed = parse(
|
||||
item["messages"][-1]["content"].split("</think>")[-1],
|
||||
|
|
@@ -310,7 +340,19 @@ class GSM8kEnv(BaseEnv):
|
|||
logprobs = item["logprobs"]
|
||||
|
||||
# remove obviously bad examples
|
||||
if len([1 for i in masks if i != -100]) < 10:
|
||||
valid_mask_count = len([1 for i in masks if i != -100])
|
||||
print(
|
||||
f"[GSM8K_DEBUG] score_candidate idx={idx} parsed_len={len(answer_parsed)} "
|
||||
f"reward={bool(reward)} valid_masked={valid_mask_count} "
|
||||
f"tokens={len(tokens)}",
|
||||
flush=True,
|
||||
)
|
||||
if valid_mask_count < 10:
|
||||
print(
|
||||
f"[GSM8K_DEBUG] drop_candidate idx={idx} reason=valid_masked_lt_10 "
|
||||
f"value={valid_mask_count}",
|
||||
flush=True,
|
||||
)
|
||||
continue
|
||||
scores["tokens"].append(tokens)
|
||||
scores["masks"].append(masks)
|
||||
|
|
@@ -323,6 +365,13 @@ class GSM8kEnv(BaseEnv):
|
|||
for score in scores["scores"]:
|
||||
self.percent_correct_buffer.append(max(score, 0))
|
||||
|
||||
if len(scores["scores"]) == 0:
|
||||
print(
|
||||
"[GSM8K_DEBUG] drop_group reason=no_valid_candidates_after_filtering",
|
||||
flush=True,
|
||||
)
|
||||
return None
|
||||
|
||||
# check if all the same
|
||||
# print(scores['scores'])
|
||||
if all([score == 1 for score in scores["scores"]]):
|
||||
|
|
@@ -330,6 +379,10 @@ class GSM8kEnv(BaseEnv):
|
|||
token_lengths = [len(token) for token in scores["tokens"]]
|
||||
if max(token_lengths) == 0:
|
||||
# What? But don't want to crash a run so just in case...
|
||||
print(
|
||||
"[GSM8K_DEBUG] drop_group reason=zero_token_length_after_penalty_branch",
|
||||
flush=True,
|
||||
)
|
||||
return None
|
||||
|
||||
# Get max allowed token length from config
|
||||
|
|
@@ -353,10 +406,20 @@ class GSM8kEnv(BaseEnv):
|
|||
# Apply linear penalty scaling from 1.0 down to 0.0
|
||||
scores["scores"].append(1.0 - percentage_of_range)
|
||||
if all([scores["scores"][0] == score for score in scores["scores"]]):
|
||||
print(
|
||||
f"[GSM8K_DEBUG] drop_group reason=all_scores_identical scores={scores['scores']}",
|
||||
flush=True,
|
||||
)
|
||||
return None # If all the same, we return None
|
||||
print(
|
||||
f"[GSM8K_DEBUG] score_done accepted={len(scores['scores'])} "
|
||||
f"scores={scores['scores']}",
|
||||
flush=True,
|
||||
)
|
||||
return scores
|
||||
else:
|
||||
# If the gold solution is not parseable, we return None
|
||||
print("[GSM8K_DEBUG] drop_group reason=gold_unparseable", flush=True)
|
||||
return None
|
||||
|
||||
async def get_next_item(self) -> GSM8kRow:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue