Match num_max_requests_at_once with group_size

This commit is contained in:
teknium1 2025-05-15 15:57:39 -07:00
parent f814f41893
commit 24c571654e

View file

@ -82,7 +82,7 @@ class InstructionFollowingEnv(BaseEnv):
model_name="NousResearch/DeepHermes-3-Llama-3-8B-Preview",
base_url="http://localhost:9004/v1",
api_key="x",
num_max_requests_at_once=32,
num_max_requests_at_once=16,
num_requests_for_eval=256,
)
]
@ -414,6 +414,8 @@ class InstructionFollowingEnv(BaseEnv):
func_name = test_item["func_name"]
args_for_verifier = test_item["args"]
print(f"DEBUG: Entering rollout_and_score_eval. Prompt: {instruction_prompt_text[:200]}...") # DEBUG
messages = [{"role": "system", "content": system_prompt}]
messages.append({"role": "user", "content": instruction_prompt_text})
@ -421,13 +423,15 @@ class InstructionFollowingEnv(BaseEnv):
messages, add_generation_prompt=True, tokenize=False
)
print(f"DEBUG: Calling self.server.completion in rollout_and_score_eval. Prompt: {prompt_str[:200]}...") # DEBUG
completion = await self.server.completion(
prompt=prompt_str,
n=1,
max_tokens=self.config.max_token_length, # Use config for max_tokens
temperature=0.7, # Temperature for eval, can be 0 for deterministic
temperature=0.2, # Temperature for eval, can be 0 for deterministic
split="eval",
)
print(f"DEBUG: Received completion in rollout_and_score_eval.") # DEBUG
model_response_text = completion.choices[0].text
score_value = await self._get_score_from_verifier(
@ -445,6 +449,7 @@ class InstructionFollowingEnv(BaseEnv):
self.eval_metrics.append(("eval/percent_correct", 0.0))
return
print(f"DEBUG: Starting evaluation. Test set size: {len(self.test)}") # DEBUG
eval_tasks = []
for test_item_dict in self.test: # self.test contains dicts after setup
eval_tasks.append(self.rollout_and_score_eval(test_item_dict))
@ -464,6 +469,7 @@ class InstructionFollowingEnv(BaseEnv):
) -> Tuple[Optional[ScoredDataGroup], List]:
# item = (prompt_messages_tuple, answer_info_dict)
# answer_info_dict = {"func_name": ..., "args": ...}
print(f"DEBUG: Entering collect_trajectories. Item: {item}") # DEBUG
prompt_messages_list = [dict(msg_fset) for msg_fset in item[0]]
answer_info = item[1]
@ -471,12 +477,19 @@ class InstructionFollowingEnv(BaseEnv):
prompt_messages_list, add_generation_prompt=True, tokenize=False
)
completions = await self.server.completion(
prompt=prompt_str,
n=self.config.group_size,
max_tokens=self.config.max_token_length,
temperature=0.8, # Temperature for diverse responses during training rollouts
)
print(f"DEBUG: Calling self.server.completion in collect_trajectories. Prompt: {prompt_str[:200]}...") # DEBUG
try:
completions = await self.server.completion(
prompt=prompt_str,
n=self.config.group_size,
max_tokens=self.config.max_token_length,
temperature=0.8, # Temperature for diverse responses during training rollouts
)
print(f"DEBUG: Received {len(completions.choices)} completions in collect_trajectories.") # DEBUG
except Exception as e:
print(f"ERROR: Exception during self.server.completion in collect_trajectories: {e}") # DEBUG
# Depending on the desired behavior, you might want to return None or raise the exception
return None, []
to_score_list = []
for choice in completions.choices:
@ -489,9 +502,11 @@ class InstructionFollowingEnv(BaseEnv):
if not to_score_list:
return None, []
print(f"DEBUG: Scoring {len(to_score_list)} trajectories in collect_trajectories.") # DEBUG
scored_data = await self.score(to_score_list)
to_backlog = [] # Backlog not currently used but part of signature
print(f"DEBUG: Exiting collect_trajectories. Scored data: {bool(scored_data)}") # DEBUG
return scored_data, to_backlog
def save_checkpoint(self, step, data=None):