Mirror of https://github.com/NousResearch/atropos.git, synced 2026-04-19 12:57:58 +00:00
match num_max_requests with groupsize
This commit is contained in:
parent f814f41893
commit 24c571654e
1 changed file with 23 additions and 8 deletions
@@ -82,7 +82,7 @@ class InstructionFollowingEnv(BaseEnv):
                 model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
                 base_url="http://localhost:9004/v1",
                 api_key="x",
-                num_max_requests_at_once=32,
+                num_max_requests_at_once=16,
                 num_requests_for_eval=256,
             )
         ]
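The drop from 32 to 16 presumably brings num_max_requests_at_once in line with the environment's group_size, the n= value requested per rollout group further down in this diff. Below is a minimal sketch of keeping the two values tied to a single constant; the field names follow this diff, while the import path and surrounding structure are assumptions rather than part of the commit:

from atroposlib.envs.base import APIServerConfig  # import path assumed, not from this commit

# Sketch only: use one constant for both the rollout group size and the
# per-server request cap so the two values cannot drift apart again.
GROUP_SIZE = 16

server_configs = [
    APIServerConfig(
        model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
        base_url="http://localhost:9004/v1",
        api_key="x",
        num_max_requests_at_once=GROUP_SIZE,  # matched with the group size by construction
        num_requests_for_eval=256,
    )
]

In the environment itself, the same value (or the env config's group_size field, presumably returned alongside this list from config_init) would feed both this server config and the n=self.config.group_size call in collect_trajectories.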
@@ -414,6 +414,8 @@ class InstructionFollowingEnv(BaseEnv):
         func_name = test_item["func_name"]
         args_for_verifier = test_item["args"]

+        print(f"DEBUG: Entering rollout_and_score_eval. Prompt: {instruction_prompt_text[:200]}...")  # DEBUG
+
         messages = [{"role": "system", "content": system_prompt}]
         messages.append({"role": "user", "content": instruction_prompt_text})

@@ -421,13 +423,15 @@ class InstructionFollowingEnv(BaseEnv):
             messages, add_generation_prompt=True, tokenize=False
         )

+        print(f"DEBUG: Calling self.server.completion in rollout_and_score_eval. Prompt: {prompt_str[:200]}...")  # DEBUG
         completion = await self.server.completion(
             prompt=prompt_str,
             n=1,
             max_tokens=self.config.max_token_length,  # Use config for max_tokens
-            temperature=0.7,  # Temperature for eval, can be 0 for deterministic
+            temperature=0.2,  # Temperature for eval, can be 0 for deterministic
             split="eval",
         )
+        print(f"DEBUG: Received completion in rollout_and_score_eval.")  # DEBUG

         model_response_text = completion.choices[0].text
         score_value = await self._get_score_from_verifier(
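The eval temperature moves from 0.7 to 0.2, and the inline comment notes it can be 0 for fully deterministic scoring. One hedged refinement would be to lift the value into the env config so a deterministic eval becomes a config change rather than a code edit; eval_temperature below is a hypothetical field that this commit does not define, and the BaseEnvConfig import path is assumed:

from atroposlib.envs.base import BaseEnvConfig  # import path assumed, not from this commit

# Sketch only: a hypothetical config field for the eval sampling temperature.
# Setting it to 0.0 would make rollout_and_score_eval fully deterministic.
class InstructionFollowingEnvConfig(BaseEnvConfig):
    eval_temperature: float = 0.2  # 0.0 => greedy, deterministic eval

The completion call in rollout_and_score_eval would then pass temperature=self.config.eval_temperature instead of the 0.2 literal.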
@@ -445,6 +449,7 @@ class InstructionFollowingEnv(BaseEnv):
             self.eval_metrics.append(("eval/percent_correct", 0.0))
             return

+        print(f"DEBUG: Starting evaluation. Test set size: {len(self.test)}")  # DEBUG
         eval_tasks = []
         for test_item_dict in self.test:  # self.test contains dicts after setup
             eval_tasks.append(self.rollout_and_score_eval(test_item_dict))
@@ -464,6 +469,7 @@ class InstructionFollowingEnv(BaseEnv):
     ) -> Tuple[Optional[ScoredDataGroup], List]:
         # item = (prompt_messages_tuple, answer_info_dict)
         # answer_info_dict = {"func_name": ..., "args": ...}
+        print(f"DEBUG: Entering collect_trajectories. Item: {item}")  # DEBUG
         prompt_messages_list = [dict(msg_fset) for msg_fset in item[0]]
         answer_info = item[1]

@@ -471,12 +477,19 @@ class InstructionFollowingEnv(BaseEnv):
             prompt_messages_list, add_generation_prompt=True, tokenize=False
         )

-        completions = await self.server.completion(
-            prompt=prompt_str,
-            n=self.config.group_size,
-            max_tokens=self.config.max_token_length,
-            temperature=0.8,  # Temperature for diverse responses during training rollouts
-        )
+        print(f"DEBUG: Calling self.server.completion in collect_trajectories. Prompt: {prompt_str[:200]}...")  # DEBUG
+        try:
+            completions = await self.server.completion(
+                prompt=prompt_str,
+                n=self.config.group_size,
+                max_tokens=self.config.max_token_length,
+                temperature=0.8,  # Temperature for diverse responses during training rollouts
+            )
+            print(f"DEBUG: Received {len(completions.choices)} completions in collect_trajectories.")  # DEBUG
+        except Exception as e:
+            print(f"ERROR: Exception during self.server.completion in collect_trajectories: {e}")  # DEBUG
+            # Depending on the desired behavior, you might want to return None or raise the exception
+            return None, []

         to_score_list = []
         for choice in completions.choices:
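Wrapping the completion call in try/except means a transient server error now drops the whole group by returning (None, []) instead of crashing the trajectory loop. If discarding data turns out to be too aggressive, a retry with backoff is a common alternative; the helper below is a sketch only, with an illustrative name and retry policy that are not part of this commit:

import asyncio

# Sketch only: retry a completion call a few times with exponential backoff
# before giving up, instead of discarding the group on the first failure.
async def completion_with_retries(server, attempts=3, backoff=2.0, **kwargs):
    for attempt in range(attempts):
        try:
            return await server.completion(**kwargs)
        except Exception as exc:  # broad catch mirrors the try/except added in this diff
            if attempt == attempts - 1:
                raise
            delay = backoff ** attempt
            print(f"WARN: completion attempt {attempt + 1} failed ({exc}); retrying in {delay:.1f}s")
            await asyncio.sleep(delay)

collect_trajectories could then call completions = await completion_with_retries(self.server, prompt=prompt_str, n=self.config.group_size, max_tokens=self.config.max_token_length, temperature=0.8) inside the same try/except, keeping the return None, [] path as the final fallback.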
@@ -489,9 +502,11 @@ class InstructionFollowingEnv(BaseEnv):
         if not to_score_list:
             return None, []

+        print(f"DEBUG: Scoring {len(to_score_list)} trajectories in collect_trajectories.")  # DEBUG
         scored_data = await self.score(to_score_list)
         to_backlog = []  # Backlog not currently used but part of signature

+        print(f"DEBUG: Exiting collect_trajectories. Scored data: {bool(scored_data)}")  # DEBUG
         return scored_data, to_backlog

     def save_checkpoint(self, step, data=None):