Add managed server to make grabbing logprobs easier with tokenized items

This commit is contained in:
dmahan93 2025-10-24 13:09:46 -07:00
parent 312f8859e3
commit 7bf4cfbf80
6 changed files with 1138 additions and 29 deletions

View file

@@ -315,8 +315,10 @@ class MathEnv(BaseEnv):
curr_length = int(curr_length * (self.curr_step / self.config.total_steps))
curr_length += self.config.start_tok_length
thinking_len = min(thinking_len, curr_length)
prompt_tokens, out_tokens, out_logprobs, finish_reasons = (
await self.server.tokens_and_logprobs_completion(
# Use managed server for automatic token/logprob tracking
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
completion = await managed.completion(
prompt=user_prompt,
n=self.config.group_size,
max_tokens=thinking_len,
@@ -324,22 +326,23 @@ class MathEnv(BaseEnv):
top_p=1.0,
stop=stop_list,
)
)
# print(completions, flush=True)
# Get tracked sequences with aligned tokens and logprobs
state = managed.get_state()
nodes = state["nodes"]
# Extract data from SequenceNodes for scoring
to_score = list()
to_backlog = list()
for i, (tokens, logprobs, finish_reason) in enumerate(
zip(out_tokens, out_logprobs, finish_reasons)
):
message = self.tokenizer.decode(prompt_tokens + tokens)
for i, (choice, node) in enumerate(zip(completion.choices, nodes)):
to_score.append(
(
message,
item[1],
finish_reason,
prompt_tokens,
tokens,
logprobs,
node.full_text, # Complete text (prompt + completion)
item[1], # Answer
choice.finish_reason, # finish_reason (already a clean string)
node.tokens, # all tokens (prompt + completion)
node.masked_tokens, # masked tokens (already formatted correctly)
node.logprobs, # logprobs (already formatted correctly)
)
)
to_postprocess = await self.score(to_score)
@@ -376,11 +379,13 @@ class MathEnv(BaseEnv):
for item in rollout_group_data:
scores["overrides"].append(dict())
resp = item[0]
finish_reason = item[2]
user_prompt_tokens = item[3]
out_toks = item[4]
out_logps = item[5]
if item[2]["type"] == "length":
finish_reason = item[2] # Now a clean string like "stop" or "length"
# ManagedServer already provides properly formatted data
tokens = item[3] # Full token sequence
masks = item[4] # Masked tokens (already formatted)
inf_logp = item[5] # Logprobs (already formatted)
if finish_reason == "length":
reward = False
if self.config.mask_too_long_completions:
scores["overrides"][-1]["set_advantage_to_zero"] = True
@@ -389,11 +394,7 @@ class MathEnv(BaseEnv):
reward = await task
if reward is None:
return None
tokens = user_prompt_tokens + out_toks
masks = [-100 for _ in range(len(user_prompt_tokens))]
masks = masks + out_toks
inf_logp = [0 for _ in range(len(user_prompt_tokens))]
inf_logp = inf_logp + out_logps
assert len(inf_logp) == len(
masks
), f"{len(inf_logp)}, {len(masks)} mismatch"
@@ -405,7 +406,7 @@ class MathEnv(BaseEnv):
# remove obviously bad examples
if len([1 for i in masks if i != -100]) < 10:
continue
if (item[2] == "length") and (not self.config.mask_too_long_completions):
if (finish_reason == "length") and (not self.config.mask_too_long_completions):
scores["overrides"][-1]["set_advantage_to_zero"] = True
scores["tokens"].append(tokens)
scores["masks"].append(masks)