add sglang specific token level logprob handling and server manager/baseline logprob/token fn

This commit is contained in:
Dakota 2025-10-16 12:38:03 -05:00
parent 4862e9972f
commit c36ec29656
4 changed files with 512 additions and 37 deletions

View file

@ -18,11 +18,11 @@ from pydantic import Field
from tqdm.asyncio import tqdm_asyncio
from atroposlib.envs.base import (
ServerBaseline,
BaseEnv,
BaseEnvConfig,
EvalHandlingEnum,
ScoredDataGroup,
ServerBaseline,
)
prompt_format = (
@ -48,6 +48,9 @@ class RSConfig(BaseEnvConfig):
percent_length_penalty: float = Field(
0.0, description="The percentage of items to have length penalty"
)
start_tok_length: int = Field(
8192, description="The starting length of the token length, scaled linearly to the max_token_length"
)
def score_answer(gold, resp) -> Optional[bool]:
@ -136,7 +139,7 @@ class MathEnv(BaseEnv):
def config_init(cls) -> Tuple[RSConfig, ServerBaseline]:
env_config = RSConfig(
tokenizer_name="Qwen/Qwen2.5-7B",
group_size=8,
group_size=16,
use_wandb=True,
rollout_server_url="http://localhost:8000",
total_steps=1000,
@ -146,10 +149,12 @@ class MathEnv(BaseEnv):
wandb_name="math",
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
eval_limit_ratio=0.1,
max_num_workers_per_node=24
)
server_configs = ServerBaseline(
model_name="default",
model_name="Qwen/Qwen2.5-7B",
num_requests_for_eval=256, # since evaling only on one...
server_type="sglang"
)
return env_config, server_configs
@ -161,7 +166,7 @@ class MathEnv(BaseEnv):
wandb_metrics["train/pass_at_groupsize"] = sum(
self.pass_at_groupsize
) / len(self.pass_at_groupsize)
self.pass_at_8 = list()
self.pass_at_groupsize = list()
if len(self.percent_correct_buffer) > 0:
wandb_metrics["train/percent_correct"] = sum(
self.percent_correct_buffer
@ -189,6 +194,10 @@ class MathEnv(BaseEnv):
table.add_data(group[0], group[1], group[2], group[3])
wandb_metrics["train/normal_rollouts"] = table
wandb_metrics["train/iter"] = self.iter
curr_length = self.config.max_token_length - self.config.start_tok_length
curr_length = int(curr_length * (self.curr_step / self.config.total_steps))
curr_length += self.config.start_tok_length
wandb_metrics["train/curr_token_length"] = curr_length
for item in self.eval_metrics:
wandb_metrics[item[0]] = item[1]
self.eval_metrics = list()
@ -231,7 +240,10 @@ class MathEnv(BaseEnv):
name,
)
)
for name, t_dataset in zip(["olympiad"], [olympiad_test_data]):
for name, t_dataset in zip(
['olympiad'],
[olympiad_test_data]
):
for item in t_dataset:
self.test.append(
(
@ -262,7 +274,7 @@ class MathEnv(BaseEnv):
"<answer>" in completion.choices[0].text
):
# assume it stopped on </answer>
resp = resp + " </answer>"
resp = resp + "</answer>"
task = loop.run_in_executor(self.mp_executor, score_answer, gold, resp)
reward = await task
if reward is None:
@ -301,30 +313,31 @@ class MathEnv(BaseEnv):
prompt=problem_format.format(problem=item[0])
)
thinking_len = thinking_len - len(self.tokenizer.encode(user_prompt))
completions = await self.server.completion(
curr_length = self.config.max_token_length - self.config.start_tok_length
curr_length = int(curr_length * (self.curr_step / self.config.total_steps))
curr_length += self.config.start_tok_length
thinking_len = min(thinking_len, curr_length)
prompt_tokens, out_tokens, out_logprobs, finish_reasons = await self.server.tokens_and_logprobs_completion(
prompt=user_prompt,
n=self.config.group_size,
max_tokens=thinking_len,
temperature=1.0,
top_p=0.95,
top_p=1.0,
stop=stop_list,
)
# print(completions, flush=True)
to_score = list()
to_backlog = list()
for i, completion in enumerate(completions.choices):
message = user_prompt + completion.text
if completion.finish_reason == "stop":
if ("</answer>" not in completion.text) and (
"<answer>" in completion.text
):
# assume it stopped on </answer>
message = message + " </answer>"
for i, (tokens, logprobs, finish_reason) in enumerate(zip(out_tokens, out_logprobs, finish_reasons)):
message = self.tokenizer.decode(prompt_tokens + tokens)
to_score.append(
(
message,
item[1],
completion.finish_reason,
user_prompt,
finish_reason,
prompt_tokens,
tokens,
logprobs,
)
)
to_postprocess = await self.score(to_score)
@ -354,13 +367,18 @@ class MathEnv(BaseEnv):
scores["scores"] = list()
scores["overrides"] = list()
scores["messages"] = list()
scores["inference_logprobs"] = list()
gold = rollout_group_data[0][1]
loop = asyncio.get_event_loop()
random.shuffle(rollout_group_data)
for item in rollout_group_data:
resp = item[0]
scores["overrides"].append(dict())
if item[2] == "length":
resp = item[0]
finish_reason = item[2]
user_prompt_tokens = item[3]
out_toks = item[4]
out_logps = item[5]
if item[2]['type'] == "length":
reward = False
if self.config.mask_too_long_completions:
scores["overrides"][-1]["set_advantage_to_zero"] = True
@ -369,23 +387,16 @@ class MathEnv(BaseEnv):
reward = await task
if reward is None:
return None
tokens = self.tokenizer.encode(resp)
user_prompt_tokens = self.tokenizer.encode(item[3])
if user_prompt_tokens[-1] == self.tokenizer.eos_token_id:
user_prompt_tokens = user_prompt_tokens[:-1]
assert all(
[
i == j
for i, j in zip(
user_prompt_tokens, tokens[: len(user_prompt_tokens)]
)
]
)
tokens = user_prompt_tokens + out_toks
masks = [-100 for _ in range(len(user_prompt_tokens))]
masks = masks + tokens[len(user_prompt_tokens) :]
masks = masks + out_toks
inf_logp = [0 for _ in range(len(user_prompt_tokens))]
inf_logp = inf_logp + out_logps
assert len(inf_logp) == len(masks), f"{len(inf_logp)}, {len(masks)} mismatch"
user_prompt = resp.split("<think>")[0]
messages = [
{"role": "user", "content": item[3]},
{"role": "assistant", "content": resp[len(item[3]) :]},
{"role": "user", "content": user_prompt},
{"role": "assistant", "content": resp[len(user_prompt) :]},
]
# remove obviously bad examples
if len([1 for i in masks if i != -100]) < 10:
@ -396,6 +407,7 @@ class MathEnv(BaseEnv):
scores["masks"].append(masks)
scores["scores"].append(1.0 if reward else -1.0)
scores["messages"].append(messages)
scores["inference_logprobs"].append(inf_logp)
if len(scores["tokens"]) >= self.config.group_size:
break
if any([score == 1.0 for score in scores["scores"]]):
@ -412,7 +424,7 @@ class MathEnv(BaseEnv):
)
# check if all the same
# print(scores['scores'])
# Fill in the correct/incorrect lenses after so we're only looking at actual training data
# Fill in the correct/incorrect lens after so we're only looking at actual training data
self.correct_answer_len.extend(
[
len(scores["tokens"][i])