mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-19 12:57:58 +00:00
Add a managed server to make grabbing logprobs easier with tokenized items
This commit is contained in:
parent
312f8859e3
commit
7bf4cfbf80
6 changed files with 1138 additions and 29 deletions
|
|
@ -315,8 +315,10 @@ class MathEnv(BaseEnv):
|
|||
curr_length = int(curr_length * (self.curr_step / self.config.total_steps))
|
||||
curr_length += self.config.start_tok_length
|
||||
thinking_len = min(thinking_len, curr_length)
|
||||
prompt_tokens, out_tokens, out_logprobs, finish_reasons = (
|
||||
await self.server.tokens_and_logprobs_completion(
|
||||
|
||||
# Use managed server for automatic token/logprob tracking
|
||||
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
|
||||
completion = await managed.completion(
|
||||
prompt=user_prompt,
|
||||
n=self.config.group_size,
|
||||
max_tokens=thinking_len,
|
||||
|
|
@ -324,22 +326,23 @@ class MathEnv(BaseEnv):
|
|||
top_p=1.0,
|
||||
stop=stop_list,
|
||||
)
|
||||
)
|
||||
# print(completions, flush=True)
|
||||
|
||||
# Get tracked sequences with aligned tokens and logprobs
|
||||
state = managed.get_state()
|
||||
nodes = state["nodes"]
|
||||
|
||||
# Extract data from SequenceNodes for scoring
|
||||
to_score = list()
|
||||
to_backlog = list()
|
||||
for i, (tokens, logprobs, finish_reason) in enumerate(
|
||||
zip(out_tokens, out_logprobs, finish_reasons)
|
||||
):
|
||||
message = self.tokenizer.decode(prompt_tokens + tokens)
|
||||
for i, (choice, node) in enumerate(zip(completion.choices, nodes)):
|
||||
to_score.append(
|
||||
(
|
||||
message,
|
||||
item[1],
|
||||
finish_reason,
|
||||
prompt_tokens,
|
||||
tokens,
|
||||
logprobs,
|
||||
node.full_text, # Complete text (prompt + completion)
|
||||
item[1], # Answer
|
||||
choice.finish_reason, # finish_reason (already a clean string)
|
||||
node.tokens, # all tokens (prompt + completion)
|
||||
node.masked_tokens, # masked tokens (already formatted correctly)
|
||||
node.logprobs, # logprobs (already formatted correctly)
|
||||
)
|
||||
)
|
||||
to_postprocess = await self.score(to_score)
|
||||
|
|
@ -376,11 +379,13 @@ class MathEnv(BaseEnv):
|
|||
for item in rollout_group_data:
|
||||
scores["overrides"].append(dict())
|
||||
resp = item[0]
|
||||
finish_reason = item[2]
|
||||
user_prompt_tokens = item[3]
|
||||
out_toks = item[4]
|
||||
out_logps = item[5]
|
||||
if item[2]["type"] == "length":
|
||||
finish_reason = item[2] # Now a clean string like "stop" or "length"
|
||||
# ManagedServer already provides properly formatted data
|
||||
tokens = item[3] # Full token sequence
|
||||
masks = item[4] # Masked tokens (already formatted)
|
||||
inf_logp = item[5] # Logprobs (already formatted)
|
||||
|
||||
if finish_reason == "length":
|
||||
reward = False
|
||||
if self.config.mask_too_long_completions:
|
||||
scores["overrides"][-1]["set_advantage_to_zero"] = True
|
||||
|
|
@ -389,11 +394,7 @@ class MathEnv(BaseEnv):
|
|||
reward = await task
|
||||
if reward is None:
|
||||
return None
|
||||
tokens = user_prompt_tokens + out_toks
|
||||
masks = [-100 for _ in range(len(user_prompt_tokens))]
|
||||
masks = masks + out_toks
|
||||
inf_logp = [0 for _ in range(len(user_prompt_tokens))]
|
||||
inf_logp = inf_logp + out_logps
|
||||
|
||||
assert len(inf_logp) == len(
|
||||
masks
|
||||
), f"{len(inf_logp)}, {len(masks)} mismatch"
|
||||
|
|
@ -405,7 +406,7 @@ class MathEnv(BaseEnv):
|
|||
# remove obviously bad examples
|
||||
if len([1 for i in masks if i != -100]) < 10:
|
||||
continue
|
||||
if (item[2] == "length") and (not self.config.mask_too_long_completions):
|
||||
if (finish_reason == "length") and (not self.config.mask_too_long_completions):
|
||||
scores["overrides"][-1]["set_advantage_to_zero"] = True
|
||||
scores["tokens"].append(tokens)
|
||||
scores["masks"].append(masks)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue