mirror of
https://github.com/NousResearch/atropos.git
synced 2026-04-23 16:54:56 +00:00
convert instruct following env to use managedserver
This commit is contained in:
parent
a4d81e36d1
commit
46f05673aa
1 changed files with 22 additions and 23 deletions
|
|
@ -19,7 +19,6 @@ from atroposlib.envs.base import (
|
|||
Item,
|
||||
ScoredDataGroup,
|
||||
)
|
||||
from atroposlib.utils.tokenize_for_trainer import tokenize_for_trainer
|
||||
|
||||
# System prompt can be reused or adapted for instruction following tasks
|
||||
system_prompt = (
|
||||
|
|
@ -592,13 +591,14 @@ class InstructionFollowingEnv(BaseEnv):
|
|||
messages, add_generation_prompt=True, tokenize=False
|
||||
)
|
||||
|
||||
completion = await self.server.completion(
|
||||
prompt=prompt_str,
|
||||
n=1,
|
||||
max_tokens=self.config.max_token_length, # Use config for max_tokens
|
||||
temperature=0.2, # Temperature for eval, can be 0 for deterministic
|
||||
split="eval",
|
||||
)
|
||||
async with self.server.managed_server(tokenizer=self.tokenizer) as managed:
|
||||
completion = await managed.completion(
|
||||
prompt=prompt_str,
|
||||
n=1,
|
||||
max_tokens=self.config.max_token_length, # Use config for max_tokens
|
||||
temperature=0.2, # Temperature for eval, can be 0 for deterministic
|
||||
split="eval",
|
||||
)
|
||||
|
||||
model_response_text = completion.choices[0].text
|
||||
score_value = await self._get_score_from_verifier(
|
||||
|
|
@ -746,15 +746,16 @@ class InstructionFollowingEnv(BaseEnv):
|
|||
await super().close()
|
||||
|
||||
async def score(
|
||||
self, rollout_group_data: List[Tuple[tuple, Dict]]
|
||||
self, rollout_group_data: List[Dict]
|
||||
) -> Optional[ScoredDataGroup]:
|
||||
# rollout_group_data is a list of (trajectory_messages_tuple, answer_info_dict)
|
||||
# rollout_group_data is a list of dicts with messages, answer_info, tokens, masks, logprobs
|
||||
# answer_info_dict = {"func_name": ..., "args": ...}
|
||||
|
||||
scores_container = ScoredDataGroup()
|
||||
scores_container["tokens"] = list()
|
||||
scores_container["masks"] = list()
|
||||
scores_container["scores"] = list()
|
||||
scores_container["inference_logprobs"] = list()
|
||||
|
||||
if not rollout_group_data:
|
||||
return None
|
||||
|
|
@ -770,8 +771,8 @@ class InstructionFollowingEnv(BaseEnv):
|
|||
failed_rollouts_for_this_group = []
|
||||
|
||||
for trajectory_item in rollout_group_data:
|
||||
full_trajectory_messages = trajectory_item[0]
|
||||
answer_info = trajectory_item[1] # {"func_name": ..., "args": ...}
|
||||
full_trajectory_messages = trajectory_item["messages"]
|
||||
answer_info = trajectory_item["answer_info"] # {"func_name": ..., "args": ...}
|
||||
|
||||
model_response_text = full_trajectory_messages[-1]["content"]
|
||||
func_name = answer_info["func_name"]
|
||||
|
|
@ -803,12 +804,9 @@ class InstructionFollowingEnv(BaseEnv):
|
|||
elif self.config.dump_failed_rollouts and reward == 0:
|
||||
failed_rollouts_for_this_group.append(rollout_dict)
|
||||
|
||||
# Tokenize the conversation for PPO training
|
||||
# Ensure full_trajectory_messages is a list of dicts
|
||||
list_of_dicts_trajectory = [dict(msg) for msg in full_trajectory_messages]
|
||||
out_dict = tokenize_for_trainer(self.tokenizer, list_of_dicts_trajectory)
|
||||
tokens = out_dict["tokens"]
|
||||
masks = out_dict["masks"]
|
||||
tokens = trajectory_item["tokens"]
|
||||
masks = trajectory_item["masks"]
|
||||
logprobs = trajectory_item["logprobs"]
|
||||
|
||||
# Filter out examples with insufficient context (too short)
|
||||
if (
|
||||
|
|
@ -818,6 +816,7 @@ class InstructionFollowingEnv(BaseEnv):
|
|||
|
||||
scores_container["tokens"].append(tokens)
|
||||
scores_container["masks"].append(masks)
|
||||
scores_container["inference_logprobs"].append(logprobs)
|
||||
scores_container["scores"].append(reward)
|
||||
|
||||
# Stop if we have enough examples for the group
|
||||
|
|
@ -832,7 +831,7 @@ class InstructionFollowingEnv(BaseEnv):
|
|||
if current_scores:
|
||||
average_score = sum(current_scores) / len(current_scores)
|
||||
# Get task info from the first rollout's answer_info
|
||||
answer_info = rollout_group_data[0][1] if rollout_group_data else {}
|
||||
answer_info = rollout_group_data[0]["answer_info"] if rollout_group_data else {}
|
||||
func_name = answer_info.get("func_name", "unknown_task")
|
||||
|
||||
# Check if group is too easy for training (but still allow data dumping)
|
||||
|
|
@ -849,7 +848,7 @@ class InstructionFollowingEnv(BaseEnv):
|
|||
<= self.config.max_group_average_for_training + 0.1
|
||||
): # Small buffer for data collection
|
||||
# Extract item info for the group - get from first rollout's answer_info
|
||||
answer_info = rollout_group_data[0][1]
|
||||
answer_info = rollout_group_data[0]["answer_info"]
|
||||
item_id = f"allenai_RLVR-IFeval_train_item_{answer_info.get('func_name', 'unknown')}_{hash(str(answer_info)) % 100000}" # noqa
|
||||
|
||||
group_data_to_save = {
|
||||
|
|
@ -863,7 +862,7 @@ class InstructionFollowingEnv(BaseEnv):
|
|||
|
||||
if failed_rollouts_for_this_group:
|
||||
# Extract item info for the failed group
|
||||
answer_info = rollout_group_data[0][1]
|
||||
answer_info = rollout_group_data[0]["answer_info"]
|
||||
item_id = f"allenai_RLVR-IFeval_train_item_{answer_info.get('func_name', 'unknown')}_{hash(str(answer_info)) % 100000}" # noqa
|
||||
|
||||
failed_group_data_to_save = {
|
||||
|
|
@ -906,7 +905,7 @@ class InstructionFollowingEnv(BaseEnv):
|
|||
# Create group data structure and add to buffers for data dumping (for training groups)
|
||||
if rollouts_for_this_group:
|
||||
# Extract item info for the group - get from first rollout's answer_info
|
||||
answer_info = rollout_group_data[0][1]
|
||||
answer_info = rollout_group_data[0]["answer_info"]
|
||||
item_id = f"allenai_RLVR-IFeval_train_item_{answer_info.get('func_name', 'unknown')}_{hash(str(answer_info)) % 100000}" # noqa
|
||||
|
||||
group_data_to_save = {
|
||||
|
|
@ -924,7 +923,7 @@ class InstructionFollowingEnv(BaseEnv):
|
|||
|
||||
if failed_rollouts_for_this_group:
|
||||
# Extract item info for the failed group
|
||||
answer_info = rollout_group_data[0][1]
|
||||
answer_info = rollout_group_data[0]["answer_info"]
|
||||
item_id = f"allenai_RLVR-IFeval_train_item_{answer_info.get('func_name', 'unknown')}_{hash(str(answer_info)) % 100000}" # noqa
|
||||
|
||||
failed_group_data_to_save = {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue