mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-24 17:04:55 +00:00
Add SGLang-specific token-level logprob handling and server manager/baseline logprob/token functions
This commit is contained in:
parent
4862e9972f
commit
c36ec29656
4 changed files with 512 additions and 37 deletions
|
|
@ -18,11 +18,11 @@ from pydantic import Field
|
|||
from tqdm.asyncio import tqdm_asyncio
|
||||
|
||||
from atroposlib.envs.base import (
|
||||
ServerBaseline,
|
||||
BaseEnv,
|
||||
BaseEnvConfig,
|
||||
EvalHandlingEnum,
|
||||
ScoredDataGroup,
|
||||
ServerBaseline,
|
||||
)
|
||||
|
||||
prompt_format = (
|
||||
|
|
@ -48,6 +48,9 @@ class RSConfig(BaseEnvConfig):
|
|||
percent_length_penalty: float = Field(
|
||||
0.0, description="The percentage of items to have length penalty"
|
||||
)
|
||||
start_tok_length: int = Field(
|
||||
8192, description="The starting length of the token length, scaled linearly to the max_token_length"
|
||||
)
|
||||
|
||||
|
||||
def score_answer(gold, resp) -> Optional[bool]:
|
||||
|
|
@ -136,7 +139,7 @@ class MathEnv(BaseEnv):
|
|||
def config_init(cls) -> Tuple[RSConfig, ServerBaseline]:
|
||||
env_config = RSConfig(
|
||||
tokenizer_name="Qwen/Qwen2.5-7B",
|
||||
group_size=8,
|
||||
group_size=16,
|
||||
use_wandb=True,
|
||||
rollout_server_url="http://localhost:8000",
|
||||
total_steps=1000,
|
||||
|
|
@ -146,10 +149,12 @@ class MathEnv(BaseEnv):
|
|||
wandb_name="math",
|
||||
eval_handling=EvalHandlingEnum.LIMIT_TRAIN,
|
||||
eval_limit_ratio=0.1,
|
||||
max_num_workers_per_node=24
|
||||
)
|
||||
server_configs = ServerBaseline(
|
||||
model_name="default",
|
||||
model_name="Qwen/Qwen2.5-7B",
|
||||
num_requests_for_eval=256, # since evaling only on one...
|
||||
server_type="sglang"
|
||||
)
|
||||
|
||||
return env_config, server_configs
|
||||
|
|
@ -161,7 +166,7 @@ class MathEnv(BaseEnv):
|
|||
wandb_metrics["train/pass_at_groupsize"] = sum(
|
||||
self.pass_at_groupsize
|
||||
) / len(self.pass_at_groupsize)
|
||||
self.pass_at_8 = list()
|
||||
self.pass_at_groupsize = list()
|
||||
if len(self.percent_correct_buffer) > 0:
|
||||
wandb_metrics["train/percent_correct"] = sum(
|
||||
self.percent_correct_buffer
|
||||
|
|
@ -189,6 +194,10 @@ class MathEnv(BaseEnv):
|
|||
table.add_data(group[0], group[1], group[2], group[3])
|
||||
wandb_metrics["train/normal_rollouts"] = table
|
||||
wandb_metrics["train/iter"] = self.iter
|
||||
curr_length = self.config.max_token_length - self.config.start_tok_length
|
||||
curr_length = int(curr_length * (self.curr_step / self.config.total_steps))
|
||||
curr_length += self.config.start_tok_length
|
||||
wandb_metrics["train/curr_token_length"] = curr_length
|
||||
for item in self.eval_metrics:
|
||||
wandb_metrics[item[0]] = item[1]
|
||||
self.eval_metrics = list()
|
||||
|
|
@ -231,7 +240,10 @@ class MathEnv(BaseEnv):
|
|||
name,
|
||||
)
|
||||
)
|
||||
for name, t_dataset in zip(["olympiad"], [olympiad_test_data]):
|
||||
for name, t_dataset in zip(
|
||||
['olympiad'],
|
||||
[olympiad_test_data]
|
||||
):
|
||||
for item in t_dataset:
|
||||
self.test.append(
|
||||
(
|
||||
|
|
@ -262,7 +274,7 @@ class MathEnv(BaseEnv):
|
|||
"<answer>" in completion.choices[0].text
|
||||
):
|
||||
# assume it stopped on </answer>
|
||||
resp = resp + " </answer>"
|
||||
resp = resp + "</answer>"
|
||||
task = loop.run_in_executor(self.mp_executor, score_answer, gold, resp)
|
||||
reward = await task
|
||||
if reward is None:
|
||||
|
|
@ -301,30 +313,31 @@ class MathEnv(BaseEnv):
|
|||
prompt=problem_format.format(problem=item[0])
|
||||
)
|
||||
thinking_len = thinking_len - len(self.tokenizer.encode(user_prompt))
|
||||
completions = await self.server.completion(
|
||||
curr_length = self.config.max_token_length - self.config.start_tok_length
|
||||
curr_length = int(curr_length * (self.curr_step / self.config.total_steps))
|
||||
curr_length += self.config.start_tok_length
|
||||
thinking_len = min(thinking_len, curr_length)
|
||||
prompt_tokens, out_tokens, out_logprobs, finish_reasons = await self.server.tokens_and_logprobs_completion(
|
||||
prompt=user_prompt,
|
||||
n=self.config.group_size,
|
||||
max_tokens=thinking_len,
|
||||
temperature=1.0,
|
||||
top_p=0.95,
|
||||
top_p=1.0,
|
||||
stop=stop_list,
|
||||
)
|
||||
# print(completions, flush=True)
|
||||
to_score = list()
|
||||
to_backlog = list()
|
||||
for i, completion in enumerate(completions.choices):
|
||||
message = user_prompt + completion.text
|
||||
if completion.finish_reason == "stop":
|
||||
if ("</answer>" not in completion.text) and (
|
||||
"<answer>" in completion.text
|
||||
):
|
||||
# assume it stopped on </answer>
|
||||
message = message + " </answer>"
|
||||
for i, (tokens, logprobs, finish_reason) in enumerate(zip(out_tokens, out_logprobs, finish_reasons)):
|
||||
message = self.tokenizer.decode(prompt_tokens + tokens)
|
||||
to_score.append(
|
||||
(
|
||||
message,
|
||||
item[1],
|
||||
completion.finish_reason,
|
||||
user_prompt,
|
||||
finish_reason,
|
||||
prompt_tokens,
|
||||
tokens,
|
||||
logprobs,
|
||||
)
|
||||
)
|
||||
to_postprocess = await self.score(to_score)
|
||||
|
|
@ -354,13 +367,18 @@ class MathEnv(BaseEnv):
|
|||
scores["scores"] = list()
|
||||
scores["overrides"] = list()
|
||||
scores["messages"] = list()
|
||||
scores["inference_logprobs"] = list()
|
||||
gold = rollout_group_data[0][1]
|
||||
loop = asyncio.get_event_loop()
|
||||
random.shuffle(rollout_group_data)
|
||||
for item in rollout_group_data:
|
||||
resp = item[0]
|
||||
scores["overrides"].append(dict())
|
||||
if item[2] == "length":
|
||||
resp = item[0]
|
||||
finish_reason = item[2]
|
||||
user_prompt_tokens = item[3]
|
||||
out_toks = item[4]
|
||||
out_logps = item[5]
|
||||
if item[2]['type'] == "length":
|
||||
reward = False
|
||||
if self.config.mask_too_long_completions:
|
||||
scores["overrides"][-1]["set_advantage_to_zero"] = True
|
||||
|
|
@ -369,23 +387,16 @@ class MathEnv(BaseEnv):
|
|||
reward = await task
|
||||
if reward is None:
|
||||
return None
|
||||
tokens = self.tokenizer.encode(resp)
|
||||
user_prompt_tokens = self.tokenizer.encode(item[3])
|
||||
if user_prompt_tokens[-1] == self.tokenizer.eos_token_id:
|
||||
user_prompt_tokens = user_prompt_tokens[:-1]
|
||||
assert all(
|
||||
[
|
||||
i == j
|
||||
for i, j in zip(
|
||||
user_prompt_tokens, tokens[: len(user_prompt_tokens)]
|
||||
)
|
||||
]
|
||||
)
|
||||
tokens = user_prompt_tokens + out_toks
|
||||
masks = [-100 for _ in range(len(user_prompt_tokens))]
|
||||
masks = masks + tokens[len(user_prompt_tokens) :]
|
||||
masks = masks + out_toks
|
||||
inf_logp = [0 for _ in range(len(user_prompt_tokens))]
|
||||
inf_logp = inf_logp + out_logps
|
||||
assert len(inf_logp) == len(masks), f"{len(inf_logp)}, {len(masks)} mismatch"
|
||||
user_prompt = resp.split("<think>")[0]
|
||||
messages = [
|
||||
{"role": "user", "content": item[3]},
|
||||
{"role": "assistant", "content": resp[len(item[3]) :]},
|
||||
{"role": "user", "content": user_prompt},
|
||||
{"role": "assistant", "content": resp[len(user_prompt) :]},
|
||||
]
|
||||
# remove obviously bad examples
|
||||
if len([1 for i in masks if i != -100]) < 10:
|
||||
|
|
@ -396,6 +407,7 @@ class MathEnv(BaseEnv):
|
|||
scores["masks"].append(masks)
|
||||
scores["scores"].append(1.0 if reward else -1.0)
|
||||
scores["messages"].append(messages)
|
||||
scores["inference_logprobs"].append(inf_logp)
|
||||
if len(scores["tokens"]) >= self.config.group_size:
|
||||
break
|
||||
if any([score == 1.0 for score in scores["scores"]]):
|
||||
|
|
@ -412,7 +424,7 @@ class MathEnv(BaseEnv):
|
|||
)
|
||||
# check if all the same
|
||||
# print(scores['scores'])
|
||||
# Fill in the correct/incorrect lenses after so we're only looking at actual training data
|
||||
# Fill in the correct/incorrect lens after so we're only looking at actual training data
|
||||
self.correct_answer_len.extend(
|
||||
[
|
||||
len(scores["tokens"][i])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue